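"""Training entry point for Sylber (PyTorch Lightning + Hydra).

Illustrative launch command (a sketch: the override keys mirror config fields this
script reads, but the paths and values are placeholders):

    python train.py gpus=0 max_steps=100000 speech_model_ckpt=/path/to/pretrained.pt
"""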
from dataset.collective_audio_segment import SpeechDataModule
from model.sylber import SylberTrainer
import lightning as pl
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
import hydra
import torch
from collections import OrderedDict
from weakref import proxy

torch.set_float32_matmul_precision('medium')

class ModelCheckpointWithEMA(ModelCheckpoint):
    """Checkpoint callback that also saves the EMA state dicts alongside each checkpoint."""

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        trainer.save_checkpoint(filepath, self.save_weights_only)

        # Collect the EMA weights from the underlying network and save them separately.
        model = trainer.lightning_module.net
        ema_dict = OrderedDict()
        ema_dict['ema'] = model.ema.model.state_dict()
        if model.lm_ema is not None:
            ema_dict['lm_ema'] = model.lm_ema.model.state_dict()
        if model.input_ema is not None:
            ema_dict['input_ema'] = model.input_ema.model.state_dict()
        if model.logit_ema is not None:
            ema_dict['logit_ema'] = model.logit_ema.model.state_dict()
        torch.save(ema_dict, "ema_dict.ckpt")

        self._last_global_step_saved = trainer.global_step
        self._last_checkpoint_saved = filepath

        # notify loggers
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))
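

# Hypothetical helper (a sketch, not called anywhere in this script): shows how the
# EMA weights dumped by ModelCheckpointWithEMA above could be restored. The attribute
# names mirror the keys saved in _save_checkpoint; adjust if the SylberTrainer API differs.
def load_ema_weights(trainer_module: SylberTrainer, ema_path: str = "ema_dict.ckpt") -> None:
    ema_dict = torch.load(ema_path, map_location='cpu')
    net = trainer_module.net
    net.ema.model.load_state_dict(ema_dict['ema'])
    if net.lm_ema is not None and 'lm_ema' in ema_dict:
        net.lm_ema.model.load_state_dict(ema_dict['lm_ema'])
    if net.input_ema is not None and 'input_ema' in ema_dict:
        net.input_ema.model.load_state_dict(ema_dict['input_ema'])
    if net.logit_ema is not None and 'logit_ema' in ema_dict:
        net.logit_ema.model.load_state_dict(ema_dict['logit_ema'])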


@hydra.main(config_path='sylber_configs', config_name='sylber_base')
def main(cfg):
    print(cfg)

    # datamodule
    datamodule = SpeechDataModule(**cfg['data'])

    # model
    model = SylberTrainer(**cfg['model'])

    # Optionally warm-start the speech encoder from a pre-trained checkpoint.
    if 'speech_model_ckpt' in cfg.keys() and cfg['speech_model_ckpt'] is not None:
        state_dict = torch.load(cfg['speech_model_ckpt'], map_location='cpu')
        model.net.speech_model.load_state_dict(state_dict, strict=False)
        print("Pre-trained checkpoint loaded")

    # Optionally load weights from a previous training stage.
    if 'model_ckpt' in cfg.keys() and cfg['model_ckpt'] is not None:
        state_dict = torch.load(cfg['model_ckpt'], map_location='cpu')['state_dict']
        try:
            model.load_state_dict(state_dict, strict=False)
        except RuntimeError:
            # The LM head may not match across stages; drop those weights and retry.
            print("Can't load LM. Removing the weights.")
            new_dict = OrderedDict()
            for name, state in state_dict.items():
                if ('net.language_model' not in name and 'net.logit' not in name
                        and 'net.input_linear' not in name):
                    new_dict[name] = state
            model.load_state_dict(new_dict, strict=False)
        print("Previous stage checkpoint loaded")
    # Callbacks
    lr_monitor = LearningRateMonitor(logging_interval='step')
    '''
    # checkpoint best
    checkpoint_callback_topk = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        filename='best-{epoch}-{val_loss:.2f}'
    )
    '''
    # checkpoint every N epochs
    checkpoint_callback_by_epoch = ModelCheckpointWithEMA(
        every_n_epochs=cfg['checkpoint_epoch'],
    )
    # keep the 5 most recent epoch checkpoints
    checkpoint_callback_last5 = ModelCheckpoint(save_top_k=5, mode='max', monitor='epoch')
    # Trainer
    # cfg['gpus'] may be a list, a single index, or a comma-separated string of indices.
    if cfg['gpus'] is not None:
        if not isinstance(cfg['gpus'], list):
            try:
                gpus = [int(cfg['gpus'])]
            except (TypeError, ValueError):
                gpus = [int(x) for x in cfg['gpus'].split(',')]
        else:
            gpus = cfg['gpus']
    else:
        gpus = None

    callbacks = [checkpoint_callback_last5,
                 checkpoint_callback_by_epoch,
                 lr_monitor]
    trainer = pl.Trainer(devices=gpus,
                         accelerator="gpu",
                         strategy="ddp_find_unused_parameters_true",
                         max_steps=cfg['max_steps'],
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=cfg['check_val_every_n_epoch'],
                         limit_val_batches=cfg['limit_val_batches'],
                         callbacks=callbacks,
                         gradient_clip_val=0.5,
                         default_root_dir=cfg.get('name', 'noname'),
                         accumulate_grad_batches=cfg['accumulate_grad_batches'],
                         )

    # fit model
    trainer.fit(model, datamodule, ckpt_path=cfg['resume_ckpt'])


if __name__ == '__main__':
    main()