-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathconfig.py
108 lines (86 loc) · 2.46 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import division
from __future__ import print_function
import os.path as osp
import numpy as np
from easydict import EasyDict as edict
__C = edict()
cfg = __C
# Dataset name: flowers, birds
__C.DATASET_NAME = "birds"
__C.CONFIG_NAME = ""
__C.DATA_DIR = ""
__C.GPU_ID = 0
__C.CUDA = True
__C.WORKERS = 6
__C.RNN_TYPE = "LSTM" # 'GRU'
__C.B_VALIDATION = False
__C.TREE = edict()
__C.TREE.BRANCH_NUM = 3
__C.TREE.BASE_SIZE = 64
# Training options
__C.TRAIN = edict()
__C.TRAIN.BATCH_SIZE = 64
__C.TRAIN.MAX_EPOCH = 600
__C.TRAIN.SNAPSHOT_INTERVAL = 2000
__C.TRAIN.DISCRIMINATOR_LR = 2e-4
__C.TRAIN.GENERATOR_LR = 2e-4
__C.TRAIN.ENCODER_LR = 2e-4
__C.TRAIN.RNN_GRAD_CLIP = 0.25
__C.TRAIN.FLAG = True
__C.TRAIN.NET_E = ""
__C.TRAIN.NET_G = ""
__C.TRAIN.B_NET_D = True
__C.TRAIN.SMOOTH = edict()
__C.TRAIN.SMOOTH.GAMMA1 = 5.0
__C.TRAIN.SMOOTH.GAMMA3 = 10.0
__C.TRAIN.SMOOTH.GAMMA2 = 5.0
__C.TRAIN.SMOOTH.LAMBDA = 1.0
# Modal options
__C.GAN = edict()
__C.GAN.DF_DIM = 64
__C.GAN.GF_DIM = 128
__C.GAN.Z_DIM = 100
__C.GAN.CONDITION_DIM = 100
__C.GAN.R_NUM = 2
__C.GAN.B_ATTENTION = True
__C.GAN.B_DCGAN = False
__C.TEXT = edict()
__C.TEXT.CAPTIONS_PER_IMAGE = 10
__C.TEXT.EMBEDDING_DIM = 256
__C.TEXT.WORDS_NUM = 18
def _merge_a_into_b(a, b):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a.
"""
if type(a) is not edict:
return
for k, v in a.items():
# a must specify keys that are in b
if k not in b:
raise KeyError("{} is not a valid config key".format(k))
# the types must match, too
old_type = type(b[k])
if old_type is not type(v):
if isinstance(b[k], np.ndarray):
v = np.array(v, dtype=b[k].dtype)
else:
raise ValueError(
("Type mismatch ({} vs. {}) " "for config key: {}").format(
type(b[k]), type(v), k
)
)
# recursively merge dicts
if type(v) is edict:
try:
_merge_a_into_b(a[k], b[k])
except:
print("Error under config key: {}".format(k))
raise
else:
b[k] = v
def cfg_from_file(filename):
"""Load a config file and merge it into the default options."""
import yaml
with open(filename, "r") as f:
yaml_cfg = edict(yaml.load(f))
_merge_a_into_b(yaml_cfg, __C)