forked from miracleyoo/pytorch-lightning-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
132 lines (107 loc) · 4.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright 2021 Zhongyang Zhang
# Contact: [email protected]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The main entry point of the whole project.

Most of this code should not need to change. Add the input arguments of
your model's constructor and of your dataset's constructor directly to the
argument parser below. MInterface and DInterface are transparent to all of
your args: every parsed argument is forwarded to them as a keyword argument.
"""
import os
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning import Trainer
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import TensorBoardLogger
from model import MInterface
from data import DInterface
from utils import load_model_path_by_args
def load_callbacks(args=None):
    """Build the list of Lightning callbacks used for training.

    Always includes early stopping and checkpointing, both driven by the
    ``val_acc`` metric (higher is better). A ``LearningRateMonitor`` is
    added when a learning-rate scheduler is enabled.

    Args:
        args: Parsed command-line namespace; only ``args.lr_scheduler`` is
            read. When omitted, the module-level ``args`` created by the
            ``__main__`` block is used, so the original zero-argument call
            ``load_callbacks()`` keeps working.

    Returns:
        list: Callback instances to hand to the Lightning ``Trainer``.
    """
    if args is None:
        # The original implementation silently depended on a global ``args``
        # that only exists when this file is run as a script; calling it
        # after importing the module raised NameError. Prefer an explicit
        # argument and fall back to the global for backward compatibility.
        args = globals().get('args')

    callbacks = [
        # Stop when val_acc has not improved by at least 0.001 for 10
        # consecutive validation checks.
        plc.EarlyStopping(
            monitor='val_acc',
            mode='max',
            patience=10,
            min_delta=0.001,
        ),
        # Keep the single best checkpoint by val_acc, plus the most recent one.
        plc.ModelCheckpoint(
            monitor='val_acc',
            filename='best-{epoch:02d}-{val_acc:.3f}',
            save_top_k=1,
            mode='max',
            save_last=True,
        ),
    ]
    # Only log the learning rate when a scheduler is actually changing it.
    if getattr(args, 'lr_scheduler', None):
        callbacks.append(plc.LearningRateMonitor(logging_interval='epoch'))
    return callbacks
def main(args):
    """Train a model using the settings in ``args``.

    Seeds the RNGs, builds the data and model interfaces from every parsed
    argument, optionally resumes from the checkpoint located by
    ``load_model_path_by_args`` and then runs ``Trainer.fit``.

    Args:
        args: Namespace produced by the ``__main__`` argument parser,
            including the Lightning Trainer arguments. Mutated in place:
            ``args.resume_from_checkpoint`` is set when a checkpoint path
            is found.
    """
    pl.seed_everything(args.seed)
    load_path = load_model_path_by_args(args)
    data_module = DInterface(**vars(args))
    # The model is constructed identically whether or not we resume; the
    # original duplicated this call in both branches of an if/else.
    model = MInterface(**vars(args))
    if load_path is not None:
        # Assigned after the model is built to keep the original statement
        # order. Presumably consumed by Trainer.from_argparse_args below
        # (``resume_from_checkpoint`` is a Trainer argument in PL 1.x).
        args.resume_from_checkpoint = load_path

    # To change the logger's save folder and/or enable the callbacks from
    # load_callbacks(), uncomment the lines below:
    # logger = TensorBoardLogger(save_dir='kfold_log', name=args.log_dir)
    # args.callbacks = load_callbacks()
    # args.logger = logger
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, data_module)
if __name__ == '__main__':
    parser = ArgumentParser()
    # Basic Training Control
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--seed', default=1234, type=int)
    parser.add_argument('--lr', default=1e-3, type=float)

    # LR Scheduler (disabled unless --lr_scheduler is given)
    parser.add_argument('--lr_scheduler', choices=['step', 'cosine'], type=str)
    parser.add_argument('--lr_decay_steps', default=20, type=int)
    parser.add_argument('--lr_decay_rate', default=0.5, type=float)
    parser.add_argument('--lr_decay_min_lr', default=1e-5, type=float)

    # Restart Control
    parser.add_argument('--load_best', action='store_true')
    parser.add_argument('--load_dir', default=None, type=str)
    parser.add_argument('--load_ver', default=None, type=str)
    parser.add_argument('--load_v_num', default=None, type=int)

    # Training Info
    parser.add_argument('--dataset', default='standard_data', type=str)
    parser.add_argument('--data_dir', default='ref/data', type=str)
    parser.add_argument('--model_name', default='standard_net', type=str)
    parser.add_argument('--loss', default='bce', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--no_augment', action='store_true')
    parser.add_argument('--log_dir', default='lightning_logs', type=str)

    # Model Hyperparameters
    parser.add_argument('--hid', default=64, type=int)
    parser.add_argument('--block_num', default=8, type=int)
    parser.add_argument('--in_channel', default=3, type=int)
    parser.add_argument('--layer_num', default=5, type=int)

    # Other
    parser.add_argument('--aug_prob', default=0.5, type=float)

    # List Arguments: per-channel normalization statistics. The defaults
    # match the widely used ImageNet RGB mean/std. These were previously
    # hard-coded after parsing; they are now overridable from the command
    # line (one value per channel), e.g. ``--mean_sen 0.5 --std_sen 0.5``.
    parser.add_argument('--mean_sen', nargs='+', type=float,
                        default=[0.485, 0.456, 0.406])
    parser.add_argument('--std_sen', nargs='+', type=float,
                        default=[0.229, 0.224, 0.225])

    # Add PyTorch Lightning's Trainer args to the parser.
    # NOTE: Trainer.add_argparse_args / Trainer.from_argparse_args exist only
    # in PyTorch Lightning 1.x; they were removed in 2.0.
    parser = Trainer.add_argparse_args(parser)

    # Reset Some Default Trainer Arguments' Default Values
    parser.set_defaults(max_epochs=100)

    args = parser.parse_args()
    main(args)