train.py (forked from miracleyoo/pytorch-lightning-template)
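"""Training entry point.

Builds the data, model, and experiment modules from an OmegaConf config,
sets up TensorBoard logging and checkpointing, snapshots the source tree
and the resolved config, then runs ``Trainer.fit``.
"""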
import os
import datetime

from omegaconf import OmegaConf

# pytorch-lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from opt import get_args
from utils.config import config, cfg_from_yaml_file, cfg_update_from_cli
from utils.training_utils import set_random_seed
from utils.pytorch_lighting_utils import ImageLogCallback
from utils.utils import make_source_code_snapshot
from dataset import build_lightning_data_module
from model import build_model_module, build_experiment_module


def train(config):
    data_module = build_lightning_data_module(config.data)
    model_module = build_model_module(config.model)
    exp_module = build_experiment_module(model_module, config)

    print(f"Starting experiment: {config.name}")
    print(f"Logging to: {os.path.join(config.log_dir, config.group)}")

    logger = TensorBoardLogger(save_dir=os.path.join(config.log_dir, config.group), name=config.name)
    # Keep the 5 best checkpoints by validation PSNR, plus the latest one.
    checkpoint_callback = ModelCheckpoint(
        dirpath=config.exp_path,
        filename="{epoch:d}",
        monitor="val/psnr",
        mode="max",
        save_top_k=5,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
    )
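    # Note: Lightning expands "{epoch:d}" with the metric name, so checkpoint
    # files land in config.exp_path as e.g. "epoch=7.ckpt" (plus "last.ckpt").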
    # Register the checkpoint callback alongside the LR monitor.
    # ImageLogCallback is imported above but not registered here; append it to
    # `callbacks` to enable image logging (its constructor is project-specific,
    # so it is left unwired).
    callbacks = [checkpoint_callback, LearningRateMonitor("step")]

    # Snapshot the source tree and the resolved config for reproducibility.
    make_source_code_snapshot(f"{config.exp_path}")
    OmegaConf.save(config=config, f=os.path.join(config.exp_path, "run_config_snapshot.yaml"))
    trainer = Trainer(
        max_epochs=config.train.num_epochs,  # maximum number of training epochs
        callbacks=callbacks,
        # resume_from_checkpoint=config.ckpt_path,  # resume from a checkpoint
        logger=logger,
        enable_model_summary=False,  # do not print the model summary
        # gpus=config.num_gpus,  # number of GPUs to use
        # accelerator="ddp" if config.num_gpus > 1 else None,  # distributed backend
        num_sanity_val_steps=1,  # run one validation batch as a sanity check before training
        benchmark=True,  # enable cudnn.benchmark for fixed-size inputs
        # profiler="simple" if config.num_gpus == 1 else None,
        # val_check_interval=1,  # how often to run validation
        log_every_n_steps=50,  # logging interval in steps
        precision=config.precision,  # e.g. 16 for half-precision acceleration
    )
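    # The commented-out `gpus=` / `accelerator="ddp"` arguments above use the
    # pre-1.7 Lightning API. On newer versions, the rough equivalent (assuming
    # config.num_gpus as used above) would be:
    #   Trainer(..., accelerator="gpu", devices=config.num_gpus,
    #           strategy="ddp" if config.num_gpus > 1 else "auto")
    # Likewise, `resume_from_checkpoint` was replaced in Lightning >= 2.0 by
    # `trainer.fit(..., ckpt_path=config.ckpt_path)`.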
    trainer.fit(exp_module, data_module)


if __name__ == '__main__':
    args = get_args()  # parse command-line arguments

    # Load the config from the YAML file, then apply CLI overrides.
    cfg_from_yaml_file(config=config, cfg_file=args.cfg_file)
    cfg_update_from_cli(config=config, args=args)
    config.merge_with_dotlist(args.opts)

    # Fix the random seed if one was given.
    if args.seed is not None:
        set_random_seed(args.seed)

    # Tag the run name with a timestamp (e.g. "05M_17D_14H_32M") and the seed.
    config.name = config.name + datetime.datetime.now().strftime("%mM_%dD_%HH_%MM") + "_seed" + str(config.seed)

    # Experiment directory: <log_dir>/<group>/<name>.
    config.exp_path = os.path.join(config.log_dir, config.group, config.name)
    print(config.exp_path)

    train(config)
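# Example invocation (flag names assumed from the `args` attributes used above;
# check opt.get_args for the exact interface):
#
#   python train.py --cfg_file configs/my_experiment.yaml --seed 42 \
#       train.num_epochs=100 precision=16
#
# Trailing dotlist overrides such as "train.num_epochs=100" are merged via
# config.merge_with_dotlist(args.opts).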