-
Notifications
You must be signed in to change notification settings - Fork 77
/
jobs.py
92 lines (68 loc) · 2.93 KB
/
jobs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import os
import random
import sys
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import yaml
from loguru import logger
from micalib.tester import Tester
from micalib.trainer import Trainer
from utils import util
# Put the directory containing this file at the front of the module search
# path. This statement runs after the imports above, so it only affects
# imports that are resolved later (e.g. lazily, by name).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '.')))
def setup(rank, world_size, port):
    """Join the NCCL process group via the ``env://`` rendezvous on localhost.

    Args:
        rank: Index of the calling process within the group.
        world_size: Total number of processes in the group.
        port: Rendezvous port; exported as ``MASTER_PORT`` after ``str()``.
    """
    # The env:// init method reads the rendezvous address from these variables.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT=str(port))
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
def deterministic(rank):
    """Seed every random number generator with ``rank`` and pin cuDNN behaviour.

    Seeds torch (CPU and CUDA), NumPy's global generator and Python's
    ``random`` module, then turns cuDNN's deterministic mode on and its
    benchmark mode off.

    Args:
        rank: Value used as the seed for all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_with in seeders:
        seed_with(rank)

    cudnn.deterministic = True
    cudnn.benchmark = False
def test(rank, world_size, cfg, args):
    """Worker entry point: join the process group and evaluate a checkpoint.

    Builds the model named by ``cfg.model.name``, wraps it in a ``Tester``
    with mesh rendering enabled, and runs the benchmark selected by
    ``args.test_dataset`` (case-insensitive ``STIRLING`` or ``NOW``). Any
    other value only logs an error. The process group is destroyed at the end.

    Args:
        rank: Index of this worker; also used as the seed and passed as the
            model/tester device.
        world_size: Total number of workers in the process group.
        cfg: Configuration object; ``cfg.model.testing`` is set to ``True``.
        args: Parsed arguments providing ``test_dataset`` and ``checkpoint``.

    NOTE(review): the port offset is drawn from NumPy's global RNG inside this
    worker, before ``deterministic(rank)`` seeds it. With more than one rank
    all workers must draw the same value to rendezvous -- confirm against the
    launcher.
    """
    port_offset = np.random.randint(low=0, high=2000)
    setup(rank, world_size, 12310 + port_offset)
    deterministic(rank)

    cfg.model.testing = True
    model_cls = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)
    mica = model_cls(cfg, rank)

    tester = Tester(nfc_model=mica, config=cfg, device=rank)
    tester.render_mesh = True

    # Map the requested benchmark to the Tester method that runs it.
    method_by_dataset = {'STIRLING': 'test_stirling', 'NOW': 'test_now'}
    method_name = method_by_dataset.get(args.test_dataset.upper())
    if method_name is None:
        logger.error('[TESTER] Test dataset was not specified!')
    else:
        getattr(tester, method_name)(args.checkpoint)

    dist.destroy_process_group()
def train(rank, world_size, cfg):
    """Worker entry point: join the process group and fit the configured model.

    Creates the log / visualisation directories under ``cfg.output_dir``,
    dumps the full configuration to ``full_config.yaml`` in the log directory,
    builds the model named by ``cfg.model.name`` and runs ``Trainer.fit()``.
    The process group is destroyed at the end.

    Args:
        rank: Index of this worker; also used as the seed and passed as the
            model/trainer device.
        world_size: Total number of workers in the process group.
        cfg: Configuration object; reads ``output_dir``, ``model.name`` and
            ``train.log_dir`` / ``train.vis_dir`` / ``train.val_vis_dir``.

    NOTE(review): the port offset is drawn from NumPy's global RNG inside this
    worker, before ``deterministic(rank)`` seeds it. With more than one rank
    all workers must draw the same value to rendezvous -- confirm against the
    launcher.
    """
    port_offset = np.random.randint(low=0, high=2000)
    setup(rank, world_size, 12310 + port_offset)

    logger.info(f'[MAIN] output_dir: {cfg.output_dir}')

    # Output sub-directories, created in this order.
    for dir_key in ('log_dir', 'vis_dir', 'val_vis_dir'):
        os.makedirs(os.path.join(cfg.output_dir, getattr(cfg.train, dir_key)), exist_ok=True)

    config_dump = os.path.join(cfg.output_dir, cfg.train.log_dir, 'full_config.yaml')
    with open(config_dump, 'w') as stream:
        yaml.dump(cfg, stream, default_flow_style=False)

    deterministic(rank)

    model_cls = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)
    nfc = model_cls(cfg, rank)
    Trainer(nfc_model=nfc, config=cfg, device=rank).fit()

    dist.destroy_process_group()