sweep.py
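"""Run a wandb hyperparameter sweep for relation extraction (RE) model training.

Creates a sweep from config.sweep_config and launches an agent that trains
one model per trial, drawing preprocessing and optimization hyperparameters
from wandb.config.
"""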
import sys
import pickle
import pytz
from datetime import datetime
import wandb
import torch
from transformers import (
AutoTokenizer,
EarlyStoppingCallback,
TrainingArguments,
)
from argparse import Namespace
from utils.args import *
from load_data.load_data import *
from model.model import *
from model.metric import *
from trainer.trainer import *
from utils.utils import *
from typing import Any


def main(config: Namespace) -> None:
    """
    Initialize the sweep and launch the wandb sweep agent.

    Args:
        config (Namespace): Namespace containing the hyperparameters needed for model training.

    Returns:
        None
    """
    def sweep_train(config: Namespace = config) -> None:
        """
        Function passed to the sweep agent via its `function` argument.

        Args:
            config (Namespace): Namespace containing the hyperparameters needed for model training.

        Returns:
            None
        """
        wandb.init(
            entity=config.wandb['entity'],
            project=config.wandb['sweep_project_name']
        )
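        # wandb.config holds the hyperparameter values sampled by the sweep controller for this trial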
        sweep_config = wandb.config
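        # seed all RNGs so individual trials are reproducible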
        seed_everything(config.seed)

        # load model and tokenizer
        model_name = config.model['name']
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # 1. load dataset
        # 2. preprocess dataset
        # 3. tokenize dataset
        revision = config.dataloader['revision']
        input_format = sweep_config['input_format']
        prompt = sweep_config['prompt']
        type_transform = sweep_config['type_transform']
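        # input_format, prompt, and type_transform come from the sweep, so preprocessing can vary per trial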
        train_dataset, train_raw_label = load_train_dataset(
            split=config.dataloader['train_split'],
            revision=revision,
            tokenizer=tokenizer,
            input_format=input_format,
            prompt=prompt,
            type_transform=type_transform,
        )
        dev_dataset, dev_raw_label = load_train_dataset(
            split=config.dataloader['valid_split'],
            revision=revision,
            tokenizer=tokenizer,
            input_format=input_format,
            prompt=prompt,
            type_transform=type_transform,
        )
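        # map string labels to integer class ids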
        train_label = label_to_num(train_raw_label)
        dev_label = label_to_num(dev_raw_label)

        # 4. make Dataset object
        re_train_dataset = REDataset(train_dataset, train_label)
        re_dev_dataset = REDataset(dev_dataset, dev_label)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        print(device)
        # 5. import model
        # setting model hyperparameters
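        # resolve the model class by name so the variant can be switched from the config alone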
        model_module = __import__('model.model', fromlist=[config.model['variant']])
        model_class = getattr(model_module, config.model['variant'])
        # Available customized classes:
        # BaseREModel, BiLSTMREModel, BiGRUREModel
        model = model_class(config, len(tokenizer))
        print(model.model_config)

        model.to(device)
        # 6. set training arguments
        ## Many more options are available beyond the ones used here.
        ## See https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments
        training_args = TrainingArguments(
            # basic settings
            output_dir=config.trainer['output_dir'],              # directory for model checkpoints
            report_to=('wandb' if config.use_wandb else 'none'),  # whether to log to wandb
            fp16=True,                                            # 16-bit floating point precision

            # training settings
            num_train_epochs=sweep_config['epochs'],              # total number of training epochs
            learning_rate=sweep_config['lr'],                     # learning rate
            weight_decay=config.optimizer['weight_decay'],        # weight decay
            adam_beta2=sweep_config['adam_beta2'],                # beta2 hyperparameter of the AdamW optimizer

            # batch size settings
            per_device_train_batch_size=config.dataloader['batch_size'],  # per-device batch size during training
            per_device_eval_batch_size=config.dataloader['batch_size'],   # per-device batch size during evaluation

            # scheduling settings
            warmup_ratio=sweep_config['warmup_ratio'],            # warmup ratio for the learning rate scheduler
            # warmup_steps=config.lr_scheduler['warmup_steps'],   # number of warmup steps for learning rate scheduler

            # logging settings
            logging_dir=config.trainer['logging_dir'],            # directory for logs
            logging_steps=config.trainer['logging_steps'],        # logging interval in steps

            # checkpointing settings
            save_total_limit=config.trainer['save_total_limit'],  # maximum number of checkpoints to keep
            save_steps=config.trainer['save_steps'],              # checkpoint interval in steps
            save_strategy=config.trainer['save_strategy'],

            # evaluation settings
            evaluation_strategy=config.trainer['evaluation_strategy'],  # evaluation strategy during training
            eval_steps=config.trainer['evaluation_steps'],              # evaluation interval in steps
            load_best_model_at_end=True,
        )
        # 7. set up the trainer
        # 8. set up the evaluation function
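        # RETrainer is the project's custom Trainer subclass; loss_cfg selects the loss defined in config.loss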
        trainer = RETrainer(
            model=model,                        # the instantiated 🤗 Transformers model to be trained
            args=training_args,                 # training arguments, defined above
            train_dataset=re_train_dataset,     # training dataset
            eval_dataset=re_dev_dataset,        # evaluation dataset
            compute_metrics=compute_metrics,    # metrics function
            # callbacks=([WandbCallback()] if config.use_wandb else []),
            # callbacks=[EarlyStoppingCallback(early_stopping_patience=config.trainer['early_stop'])],
            loss_cfg=config.loss,
        )
        # 9. train model
        trainer.train()

        # 10. save model
        trainer.save_model(config.trainer['model_dir'])
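    # register the sweep with the wandb server, then run `sweep_count` trials with sweep_train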
    sweep_id = wandb.sweep(
        sweep=config.sweep_config,
        entity=config.wandb['entity'],
        project=config.wandb['sweep_project_name'],
    )
    wandb.agent(
        sweep_id=sweep_id,
        function=sweep_train,
        count=config.wandb['sweep_count']
    )


if __name__ == '__main__':
    try:
        config_path = sys.argv[1]
    except IndexError:
        config_path = './config.yaml'
    config = parse_arguments(config_path)

    now = datetime.now(pytz.timezone('Asia/Seoul'))
    run_name = f'{config.run_name}_{now.strftime("%d-%H-%M")}'

    main(config)