-
Notifications
You must be signed in to change notification settings - Fork 0
/
common.py
96 lines (80 loc) · 3 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Other common utilities for filepaths and device info."""
import os
import tensorflow as tf
__all__ = ["config_gpu", "get_save_path"]
def config_gpu():
    """Configure TensorFlow to run on the first available GPU.

    Enables memory growth on every physical GPU (so TF does not grab all
    device memory up front), then restricts visibility to GPU 0 and sets
    ``TF_CUDNN_DETERMINISTIC`` for reproducible cuDNN kernels.

    Raises:
        RuntimeError: if no GPU is visible to TensorFlow.
    """
    gpus = tf.config.list_physical_devices("GPU")
    print(f"Num GPUs Available: {len(gpus)}")
    if not gpus:
        # This module assumes GPU execution; fail fast rather than silently
        # falling back to CPU.  RuntimeError is still caught by any caller
        # handling the old bare Exception.
        raise RuntimeError("No GPUs available. Unable to run model.")
    try:
        # Memory growth must be set on all GPUs before any are initialized.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Restrict TensorFlow to only use the first GPU.
        tf.config.set_visible_devices(gpus[0], "GPU")
        logical_gpus = tf.config.list_logical_devices("GPU")
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as e:
        # Visible devices / memory growth must be set before GPUs have been
        # initialized; report but do not abort if TF was already initialized.
        print(e)
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
    print()
def get_configs(configs_path: str, is_evaluating: bool, restart_training: bool):
    """Load the project Config object and apply CLI-flag overrides.

    Loads configs from ``configs_path``, derives the run's save/checkpoint
    paths, and selects the batch size for the requested mode.

    Args:
        configs_path: path to the config module(s) to load.
        is_evaluating: True when running with the ``--eval`` flag.
        restart_training: True when running with the ``--restart`` flag.

    Returns:
        The populated ``configs`` object.

    Raises:
        ValueError: if both ``--restart`` and ``--eval`` are set.
        FileNotFoundError: if resuming training but a checkpoint directory
            is missing.
        FileExistsError: if restarting but a checkpoint directory already
            exists (``os.makedirs(..., exist_ok=False)``).
    """
    from utils.config import configs  # pylint: disable=import-outside-toplevel

    print(f"==> loading configs from {configs_path}")
    configs.update_from_modules(configs_path)
    # Define save path.
    configs.train.save_path = get_save_path(configs_path, prefix="runs")
    # Override configs with args.  Explicit raises instead of `assert`:
    # asserts are stripped under `python -O`.
    configs.eval.is_evaluating = is_evaluating
    configs.train.restart_training = restart_training
    if configs.train.restart_training and configs.eval.is_evaluating:
        raise ValueError("Cannot set '--restart' and '--eval' flag at the same time.")
    save_path = configs.train.save_path
    configs.train.train_ckpts_path = os.path.join(save_path, "training_ckpts")
    configs.train.best_ckpt_path = os.path.join(save_path, "best_ckpt")
    if configs.train.restart_training:
        # exist_ok=False: refuse to clobber a previous run's checkpoints.
        os.makedirs(configs.train.train_ckpts_path, exist_ok=False)
        os.makedirs(configs.train.best_ckpt_path, exist_ok=False)
    else:
        # Resuming: both checkpoint directories must already exist.
        for path in (configs.train.train_ckpts_path, configs.train.best_ckpt_path):
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"Training without '--restart' flag set but {path} path does not exist."
                )
    # Evaluation and training may use different batch sizes.
    if configs.eval.is_evaluating:
        configs.dataset.batch_size = configs.eval.batch_size
    else:
        configs.dataset.batch_size = configs.train.batch_size
    return configs
def get_save_path(*configs, prefix: str = "runs") -> str:
    """Derive a checkpoint directory name from config file paths.

    Each path such as ``configs/a/b.py`` contributes the dotted segments
    ``a.b``.  Paths sharing a leading segment are merged into a trie and
    siblings are rendered bracketed, e.g. ``a.[b+c]`` for
    ``configs/a/b.py`` + ``configs/a/c.py``.

    Args:
        *configs: config file paths; a ``configs/`` prefix and ``.py``
            suffix are stripped before splitting on ``/``.
        prefix: directory under which the derived name is placed.

    Returns:
        ``os.path.join(prefix, <derived name>)``.
    """
    # Build a trie of path segments so shared prefixes collapse.
    trie: dict = {}
    for path in configs:
        node = trie
        for segment in path.replace("configs/", "").replace(".py", "").split("/"):
            node = node.setdefault(segment, {})

    def render(node: dict) -> str:
        # Each child renders as "key" or "key.<subtree>"; siblings are
        # joined with "+" and bracketed only when there is more than one.
        parts = [
            key + ("." + render(child) if child else "")
            for key, child in node.items()
        ]
        joined = "+".join(parts)
        return f"[{joined}]" if len(node) > 1 else joined

    return os.path.join(prefix, render(trie))