train.py

import logging
import os
import sys
from functools import partial
from typing import Dict, List

import numpy as np
import pandas as pd
import transformers
from datasets import load_dataset
from evaluate import load
from setproctitle import setproctitle
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

# Argument dataclasses and project-specific helpers
from utils import DatasetsArguments, ModelArguments, TNTTrainingArguments
from utils.metrics import TNTEvaluator, preprocess_logits_for_metrics
from utils.preprocess_func import bart_preprocess, gpt_preprocess, t5_preprocess

logger = logging.getLogger(__name__)

DATA_EXTENSION = "csv"


def main(model_args: ModelArguments, data_args: DatasetsArguments, training_args: TNTTrainingArguments):
    set_seed(training_args.seed)

    data_files = dict()
    if data_args.train_csv_paths is not None:
        data_files.update({"train": data_args.train_csv_paths})
    if data_args.valid_csv_paths is not None:
        data_files.update({"validation": data_args.valid_csv_paths})
    dataset = load_dataset(DATA_EXTENSION, data_files=data_files, cache_dir=model_args.cache_dir)

    # [TODO] If no validation CSVs are given, split a validation set off from the train set.
    # [TODO] Settle the argument naming: train_csv_paths vs. train_csv_path.
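
    # NOTE: a minimal sketch for the TODO above, not part of the original script.
    # If no validation CSVs were passed, carve a small validation split out of train.
    # The 5% ratio is an assumption, not a project convention.
    if "validation" not in dataset:
        split = dataset["train"].train_test_split(test_size=0.05, seed=training_args.seed)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]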

    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    model_max_length = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else config.n_ctx
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=model_max_length,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
    )
if "gpt" in model_args.model_name_or_path:
preprocess = partial(gpt_preprocess, tokenizer=tokenizer, train_type=training_args.train_type)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
elif "t5" in model_args.model_name_or_path:
preprocess = t5_preprocess
elif "bart" in model_args.model_name_or_path:
preprocess = partial(bart_preprocess, tokenizer=tokenizer, train_type=training_args.train_type)
dataset = dataset.map(preprocess, num_proc=data_args.num_proc, remove_columns=dataset["train"].column_names)
train_dataset = dataset["train"]
valid_dataset = dataset["validation"]
# [TODO] 완성하기
# model = AutoModelForSeq2SeqLM.from_pretrained(model_args.model_name_or_path)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# print("in")
tnt_metrics = TNTEvaluator(tokenizer=tokenizer).compute_metrics

    if is_main_process(training_args.local_rank):
        import wandb

        wandb.init(
            project=training_args.wandb_project,
            entity=training_args.wandb_entity,
            name=training_args.wandb_name,
        )

    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    trainer = Seq2SeqTrainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        # compute_metrics=tnt_metrics,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=valid_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    if training_args.do_train:
        # Resume from the last checkpoint if one exists.
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        else:
            checkpoint = None
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model(output_dir=training_args.output_dir)

        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(eval_dataset=valid_dataset)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        test_results = trainer.predict(test_dataset=valid_dataset)
        metrics = test_results.metrics
        metrics["predict_samples"] = len(valid_dataset)
        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)
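
        # NOTE: a minimal sketch, not part of the original script, for persisting the
        # decoded predictions next to the metrics. It assumes `test_results.predictions`
        # holds token ids (e.g. after `preprocess_logits_for_metrics` or with generation),
        # with -100 used as label padding by the collator.
        predictions = test_results.predictions
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        pd.DataFrame({"prediction": decoded_preds}).to_csv(
            os.path.join(training_args.output_dir, "predictions.csv"), index=False
        )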


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DatasetsArguments, TNTTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    main(model_args=model_args, data_args=data_args, training_args=training_args)
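
# Example launch (illustrative only; the checkpoint name and file paths are placeholders,
# and the exact flags are defined by the ModelArguments / DatasetsArguments /
# TNTTrainingArguments dataclasses in utils, which HfArgumentParser turns into CLI options):
#   python train.py \
#       --model_name_or_path <bart|t5|gpt checkpoint> \
#       --train_csv_paths data/train.csv \
#       --valid_csv_paths data/valid.csv \
#       --output_dir outputs \
#       --do_train --do_eval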