main.py
import os

# Select the GPU before torch is imported so CUDA only sees this device.
GPU_index = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_index

import logging
from shutil import copytree, ignore_patterns

import numpy as np
import torch
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader

from config import load_config
from evaluate import Evaluator
from model.voxel2mesh import Voxel2Mesh as network
from train import Trainer
from utils.utils_common import DataModes, mkdir

logger = logging.getLogger(__name__)

def init(cfg):
    save_path = cfg.save_path + cfg.save_dir_prefix + str(cfg.experiment_idx).zfill(3)
    mkdir(save_path)

    # Pick the next trial index unless one was given explicitly in the config.
    trial_id = (len([d for d in os.listdir(save_path) if 'trial' in d]) + 1) if cfg.trial_id is None else cfg.trial_id
    trial_save_path = save_path + '/trial_' + str(trial_id)

    if not os.path.isdir(trial_save_path):
        mkdir(trial_save_path)
        # Snapshot the source code of this run, skipping data, logs and wandb artefacts.
        copytree(os.getcwd(), trial_save_path + '/source_code',
                 ignore=ignore_patterns('*.git', '*.txt', '*.tif', '*.pkl', '*.off', '*.so',
                                        '*.json', '*.jsonl', '*.log', '*.patch', '*.yaml',
                                        'wandb', 'run-*'))

    # Seed all RNGs with the trial id so repeated trials are reproducible but distinct.
    seed = trial_id
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True  # cuDNN speeds up the computation

    return trial_save_path, trial_id
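
# Example (illustrative values, not the repository's actual config): with
# cfg.save_path = 'experiments/', cfg.save_dir_prefix = 'Experiment_' and
# cfg.experiment_idx = 3, the first trial of experiment 3 is created at
# 'experiments/Experiment_003/trial_1', and a snapshot of the source code is
# copied to 'experiments/Experiment_003/trial_1/source_code'.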


def main():
    exp_id = 3

    # Initialize
    cfg = load_config(exp_id)
    trial_path, trial_id = init(cfg)
    print('Experiment ID: {}, Trial ID: {}'.format(cfg.experiment_idx, trial_id))

    print("Create network")
    classifier = network(cfg)
    classifier.cuda()

    wandb.init(name='Experiment_{}/trial_{}'.format(cfg.experiment_idx, trial_id), project="vm-net", dir=trial_path)

    print("Initialize optimizer")
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr=cfg.learning_rate)

    print("Load pre-processed data")
    data_obj = cfg.data_obj
    data = data_obj.quick_load_data(cfg, trial_id)
    loader = DataLoader(data[DataModes.TRAINING], batch_size=classifier.config.batch_size, shuffle=True)
    print("Trainset length: {}".format(len(loader)))  # number of training batches, not samples

    print("Initialize evaluator")
    evaluator = Evaluator(classifier, optimizer, data, trial_path, cfg, data_obj)

    print("Initialize trainer")
    trainer = Trainer(classifier, loader, optimizer, cfg.numb_of_itrs, cfg.eval_every, trial_path, evaluator)

    if cfg.trial_id is not None:
        # Resuming an existing trial: restore its best checkpoint.
        print("Loading pretrained network")
        save_path = trial_path + '/best_performance/model.pth'
        checkpoint = torch.load(save_path)
        classifier.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
    else:
        epoch = 0

    trainer.train(start_iteration=epoch)
    # To evaluate a pretrained model instead of training, comment out the
    # trainer.train(...) call above and uncomment the line below.
    # evaluator.evaluate(epoch)
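

# A minimal evaluation-only sketch, not part of the original script. It reuses
# the names defined above and the 'best_performance/model.pth' checkpoint
# layout that main() reads from; the function name evaluate_only and its
# arguments are assumptions for illustration, and Evaluator.evaluate(epoch)
# follows the commented-out call in main().
def evaluate_only(exp_id, trial_id):
    cfg = load_config(exp_id)
    cfg.trial_id = trial_id  # reuse the existing trial directory instead of creating a new one
    trial_path, trial_id = init(cfg)

    model = network(cfg).cuda()
    wandb.init(name='Experiment_{}/trial_{}'.format(cfg.experiment_idx, trial_id), project="vm-net", dir=trial_path)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.learning_rate)

    # Restore the best checkpoint saved by the trainer.
    checkpoint = torch.load(trial_path + '/best_performance/model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    data_obj = cfg.data_obj
    data = data_obj.quick_load_data(cfg, trial_id)
    evaluator = Evaluator(model, optimizer, data, trial_path, cfg, data_obj)
    evaluator.evaluate(checkpoint['epoch'])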


if __name__ == "__main__":
    main()