-
Notifications
You must be signed in to change notification settings - Fork 22
/
train.py
201 lines (173 loc) · 7.67 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
os.environ["WANDB_PROJECT"]= "lmms-ft"
from dataclasses import asdict
import math
from pathlib import Path
from typing import List, Optional
import yaml
from accelerate.utils import DistributedType
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
import transformers
from transformers import Trainer, deepspeed
from arguments import ModelArguments, DataArguments, TrainingArguments, LoraArguments
from collators import COLLATORS
from datasets import LazySupervisedDataset
from loaders import LOADERS
from supported_models import MODULE_KEYWORDS
from utils import (
rank0_print, find_all_linear_names, safe_save_model_for_hf_trainer,
get_peft_state_maybe_zero_3, TrainerWithCustomSampler
)
def _dump_arguments(args_dir, stem, args):
    """Serialize the dataclass instance `args` as YAML into `<args_dir>/<stem>.yaml`."""
    # The context manager guarantees the handle is flushed and closed even
    # if serialization raises part-way through.
    with open(args_dir / f"{stem}.yaml", "w") as stream:
        yaml.dump(asdict(args), stream)


def _freeze_components(model, component_keys, banner):
    """Log `banner`, then disable gradients for every attribute of `model` named in `component_keys`."""
    rank0_print(banner)
    for key in component_keys:
        rank0_print(f"\t{key}")
        # NOTE(review): `eval` resolves the key as an attribute expression on
        # the local `model`. Keys come from the project's MODULE_KEYWORDS
        # table, not from user input. If they are always plain dotted paths,
        # `operator.attrgetter` would be a safer replacement -- confirm
        # against supported_models before switching.
        eval(f"model.{key}").requires_grad_(False)


def train():
    """Run a fine-tuning job configured entirely from command-line arguments.

    Steps, in order:
      1. Parse the command line into model / data / training / LoRA
         argument dataclasses and dump each one as YAML under
         ``<output_dir>/arguments``.
      2. Pick the compute dtype and, for Q-LoRA, the device map and the
         4-bit quantization config.
      3. Load model, tokenizer, processor and config through the loader
         registered for ``model_args.model_family_id``.
      4. Freeze the components that are not being trained and, if requested,
         wrap the model with a PEFT LoRA adapter.
      5. Build the train (and optional eval) dataset plus the data collator.
      6. Train, save the trainer state and save the final model.

    Raises:
        ValueError: if ``output_dir`` is missing, or if Q-LoRA is requested
            together with FSDP or DeepSpeed ZeRO-3.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()

    # dumping arguments
    output_dir = getattr(training_args, 'output_dir', None)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if output_dir is None:
        raise ValueError("output_dir is required")
    args_dir = Path(output_dir) / "arguments"
    args_dir.mkdir(parents=True, exist_ok=True)
    for stem, args in (
        ("model", model_args),
        ("data", data_args),
        ("training", training_args),
        ("lora", lora_args),
    ):
        _dump_arguments(args_dir, stem, args)

    # fp16 takes precedence over bf16; fall back to full precision otherwise
    if training_args.fp16:
        compute_dtype = torch.float16
    elif training_args.bf16:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float32

    if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    device_map = None
    if lora_args.q_lora:
        # In a multi-process run (WORLD_SIZE != 1) pin the whole model to the
        # device indexed by this process's LOCAL_RANK.
        if int(os.environ.get("WORLD_SIZE", 1)) != 1:
            device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        if training_args.fsdp or deepspeed.is_deepspeed_zero3_enabled():
            raise ValueError("FSDP or ZeRO3 are not compatible with QLoRA.")

    # llm quantization config (for q-lora)
    bnb_config = None
    if lora_args.use_lora and lora_args.q_lora:
        from transformers import BitsAndBytesConfig
        rank0_print("Quantization for LLM enabled...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type="nf4",
        )

    # load model, tokenizer, processor
    rank0_print("Loading model, tokenizer, processor...")
    loader = LOADERS[model_args.model_family_id](
        model_hf_path=model_args.model_hf_path,
        model_local_path=model_args.model_local_path,
        compute_dtype=compute_dtype,
        bnb_config=bnb_config,
        use_flash_attn=training_args.use_flash_attn,
        device_map=device_map,
    )
    model, tokenizer, processor, config = loader.load()
    tokenizer.model_max_length = training_args.model_max_length

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    # freeze certain params
    family_keywords = MODULE_KEYWORDS[model_args.model_family_id]
    vision_encoder_keys = family_keywords["vision_encoder"]
    if not training_args.train_vision_encoder:
        _freeze_components(
            model, vision_encoder_keys, "Vision encoder is freezed... including:"
        )

    vision_projector_keys = family_keywords["vision_projector"]
    if not training_args.train_vision_projector:
        _freeze_components(
            model, vision_projector_keys, "Vision projector is freezed... including:"
        )

    # other components preparation (e.g., image_newline, vision_resampler)
    # we will just freeze these
    if "others" in family_keywords:
        _freeze_components(
            model,
            family_keywords["others"],
            "Other multimodal component is freezed... including:",
        )

    # lora preparation
    llm_keys = family_keywords["llm"]
    if not (lora_args.use_lora or (training_args.train_vision_encoder and lora_args.use_vision_lora)):
        rank0_print("No LoRA enabled...")
    else:
        named_modules = {n: m for n, m in model.named_modules()}
        lora_modules = []  # linear layers that receive LoRA adapters
        full_modules = []  # components passed to PEFT as `modules_to_save`

        if training_args.train_vision_encoder and lora_args.use_vision_lora:
            rank0_print("LoRA for vision encoder enabled...")
            lora_modules.extend(find_all_linear_names(named_modules, vision_encoder_keys))
        elif training_args.train_vision_encoder:
            rank0_print("Vision encoder will be fully trained...")
            full_modules.extend(vision_encoder_keys)

        if lora_args.use_lora:
            rank0_print("LoRA for LLM enabled...")
            lora_modules.extend(find_all_linear_names(named_modules, llm_keys))
        else:
            rank0_print("LLM will be fully trained...")
            full_modules.extend(llm_keys)

        if training_args.train_vision_projector:
            rank0_print("Vision projector will be fully trained...")
            full_modules.extend(vision_projector_keys)

        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_modules,
            modules_to_save=full_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
        )

        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )

        model = get_peft_model(model, lora_config)

    # print trainable parameters for inspection
    rank0_print("Trainable parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(f"\t{name}")

    # load data
    rank0_print("Loading data...")
    train_dataset = LazySupervisedDataset(
        data_path=data_args.data_path,
        image_folder=data_args.image_folder,
        video_folder=data_args.video_folder,
        num_frames=data_args.num_frames,
        model_family_id=model_args.model_family_id,
        user_key=data_args.user_key,
        assistant_key=data_args.assistant_key
    )
    if data_args.eval_data_path:
        eval_dataset = LazySupervisedDataset(
            data_path=data_args.eval_data_path,
            image_folder=data_args.image_folder,
            video_folder=data_args.video_folder,
            num_frames=data_args.num_frames,
            model_family_id=model_args.model_family_id,
            user_key=data_args.user_key,
            assistant_key=data_args.assistant_key
        )
    else:
        # No eval data supplied: turn evaluation off on the training arguments.
        eval_dataset = None
        training_args.eval_strategy = "no"

    # data collator
    data_collator = COLLATORS[model_args.model_family_id](
        config=config,
        tokenizer=tokenizer,
        processor=processor,
        mask_question_tokens=training_args.mask_question_tokens
    )

    # trainer
    trainer = TrainerWithCustomSampler(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=output_dir)
# Script entry point: start the fine-tuning run only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train()