-
Notifications
You must be signed in to change notification settings - Fork 35
/
training.py
122 lines (96 loc) · 3.45 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import logging
import os
from pathlib import Path
import hydra
import wandb
from datasets import load_dataset
from hydra.utils import to_absolute_path
from omegaconf import OmegaConf
from transformers import TrainingArguments
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
logger = logging.getLogger(__name__)
def setup_wandb(args: dict):
    """
    Start a Weights & Biases run tracking this training job.

    The full process environment is merged into the run config alongside
    the (already-containerized) Hydra config, so the run page records both.
    Returns the active ``wandb`` run object.
    """
    # Snapshot every environment variable into a plain dict.
    env_snapshot = dict(os.environ)
    return wandb.init(
        job_type="train",
        project=args["experiment"],
        entity=args["wandb_entity"],
        config={**args, **env_snapshot},
        tags=["train"],
    )
@hydra.main(version_base=None, config_path="./configs", config_name="training")
def main(args):
    """
    Fine-tune a causal LM with TRL's SFTTrainer, driven by a Hydra config.

    Expects the config to provide: experiment, train.* (TrainingArguments
    kwargs), data_file, instruction, shuffle, limit, model (a Hydra-
    instantiable wrapper exposing .model and .tokenizer), template,
    input_key, output_key, resume_checkpoint, hfhub_tag, use_wandb.
    Saves the trained model under ``{train.output_dir}/checkpoint/`` and
    optionally pushes it to the Hugging Face Hub.
    """
    logger.info(OmegaConf.to_yaml(args))
    OmegaConf.set_struct(args, False)
    logger.info(f"Experiment name: {args.experiment}")
    logger.info(f"Output path: {args.train.output_dir}")
    if args.use_wandb:
        setup_wandb(OmegaConf.to_container(args))
    logger.info(f"Loading dataset: {args.data_file}")
    dataset = load_dataset(
        "json", data_files=to_absolute_path(args.data_file), split="train"
    )
    logger.info(f"Loading instruction from file {args.instruction}...")
    # Path.read_text closes the handle; the previous bare open(...).read()
    # leaked the file descriptor.
    instruction = Path(args.instruction).read_text()
    logger.info(f"Loaded instruction: {instruction}")
    if args.shuffle:
        # NOTE(review): the truthy `shuffle` value doubles as the RNG seed
        # (shuffle=True -> seed=1) — confirm this is intended.
        dataset = dataset.shuffle(seed=args.shuffle)
    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))
    model_class = hydra.utils.instantiate(args.model, _convert_="object")
    logger.info("Model was loaded.")
    # Read the (loop-invariant) template once instead of re-opening the
    # file for every example inside dataset.map.
    template = Path(args.template).read_text() if args.template else None

    def format_answer(example):
        """Render one example via the text template or as chat messages."""
        query = example[args.input_key]
        if args.model.instruction_in_prompt:
            query = instruction + "\n" + query
        out = example[args.output_key]
        # Outputs may arrive as a single-element list; None becomes "".
        output = (out[0] if isinstance(out, list) else out) or ""
        if template is not None:
            return template.format(query=query, output=output)
        else:
            messages = [
                {
                    "role": "system",
                    "content": instruction,
                },
                {"role": "user", "content": query},
                {
                    "role": "assistant",
                    "content": output,
                },
            ]
            return dict(messages=messages)

    dataset = dataset.map(format_answer)
    # Compute loss only on tokens after `completion_start`, i.e. on the
    # assistant completion rather than the prompt.
    collator = DataCollatorForCompletionOnlyLM(
        model_class.tokenizer.encode(
            args.model.completion_start, add_special_tokens=False
        ),
        tokenizer=model_class.tokenizer,
    )
    logger.info("Initializing training arguments...")
    training_args = TrainingArguments(**args.train)
    logger.info("Starting to train...")
    trainer = SFTTrainer(
        model=model_class.model,
        args=training_args,
        data_collator=collator,
        train_dataset=dataset,
        dataset_batch_size=1,
        packing=False,
        max_seq_length=args.model.max_sequence_len,
        dataset_kwargs=dict(add_special_tokens=False),
    )
    trainer.train(resume_from_checkpoint=args.resume_checkpoint)
    logger.info(
        f"Finished training; saving model to {args.train.output_dir}/checkpoint..."
    )
    trainer.model.save_pretrained(Path(args.train.output_dir) / "checkpoint/")
    if args.hfhub_tag:
        trainer.model.push_to_hub(args.hfhub_tag, private=True)
if __name__ == "__main__":
main()