Skip to content

Commit

Permalink
Make several functions that construct CfgNode be classmethod, so that…
Browse files Browse the repository at this point in the history
… they are polymorphic and can be used to construct a subclass of CfgNode.
  • Loading branch information
ppwwyyxx authored and rbgirshick committed Mar 4, 2019
1 parent 647a493 commit 3ba3ecf
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 113 deletions.
230 changes: 127 additions & 103 deletions yacs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,7 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
# Recursively convert nested dictionaries in init_dict into CfgNodes
init_dict = {} if init_dict is None else init_dict
key_list = [] if key_list is None else key_list
for k, v in init_dict.items():
if type(v) is dict:
# Convert dict to CfgNode
init_dict[k] = CfgNode(v, key_list=key_list + [k])
else:
# Check for valid leaf type or nested CfgNode
_assert_with_logging(
_valid_type(v, allow_cfg_node=True),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list + [k]), type(v), _VALID_TYPES
),
)
init_dict = self._create_config_tree_from_dict(init_dict, key_list)
super(CfgNode, self).__init__(init_dict)
# Manage if the CfgNode is frozen or not
self.__dict__[CfgNode.IMMUTABLE] = False
Expand All @@ -119,6 +108,32 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
# Allow new attributes after initialisation
self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed

@classmethod
def _create_config_tree_from_dict(cls, dic, key_list):
"""
Create a configuration tree using the given dict.
Any dict-like objects inside dict will be treated as a new CfgNode.
Args:
dic (dict):
key_list (list[str]): a list of names which index this CfgNode from the root.
Currently only used for logging purposes.
"""
dic = copy.deepcopy(dic)
for k, v in dic.items():
if isinstance(v, dict):
# Convert dict to CfgNode
dic[k] = cls(v, key_list=key_list + [k])
else:
# Check for valid leaf type or nested CfgNode
_assert_with_logging(
_valid_type(v, allow_cfg_node=False),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list + [k]), type(v), _VALID_TYPES
),
)
return dic

def __getattr__(self, name):
if name in self:
return self[name]
Expand Down Expand Up @@ -194,7 +209,7 @@ def convert_to_dict(cfg_node, key_list):
def merge_from_file(self, cfg_filename):
"""Load a yaml config file and merge it this CfgNode."""
with open(cfg_filename, "r") as f:
cfg = load_cfg(f)
cfg = self.load_cfg(f)
self.merge_from_other_cfg(cfg)

def merge_from_other_cfg(self, cfg_other):
Expand Down Expand Up @@ -226,7 +241,7 @@ def merge_from_list(self, cfg_list):
d = d[subkey]
subkey = key_list[-1]
_assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
value = _decode_cfg_value(v)
value = self._decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
d[subkey] = value

Expand Down Expand Up @@ -310,66 +325,108 @@ def raise_key_rename_error(self, full_key):
def is_new_allowed(self):
return self.__dict__[CfgNode.NEW_ALLOWED]


def load_cfg(cfg_file_obj_or_str):
"""Load a cfg. Supports loading from:
- A file object backed by a YAML file
- A file object backed by a Python source file that exports an attribute
"cfg" that is either a dict or a CfgNode
- A string that can be parsed as valid YAML
"""
_assert_with_logging(
isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
"Expected first argument to be of type {} or {}, but it was {}".format(
_FILE_TYPES, str, type(cfg_file_obj_or_str)
),
)
if isinstance(cfg_file_obj_or_str, str):
return _load_cfg_from_yaml_str(cfg_file_obj_or_str)
elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
return _load_cfg_from_file(cfg_file_obj_or_str)
else:
raise NotImplementedError("Impossible to reach here (unless there's a bug)")


def _load_cfg_from_file(file_obj):
"""Load a config from a YAML file or a Python source file."""
_, file_extension = os.path.splitext(file_obj.name)
if file_extension in _YAML_EXTS:
return _load_cfg_from_yaml_str(file_obj.read())
elif file_extension in _PY_EXTS:
return _load_cfg_py_source(file_obj.name)
else:
raise Exception(
"Attempt to load from an unsupported file type {}; "
"only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
@classmethod
def load_cfg(cls, cfg_file_obj_or_str):
"""
Load a cfg.
Args:
cfg_file_obj_or_str (str or file):
Supports loading from:
- A file object backed by a YAML file
- A file object backed by a Python source file that exports an attribute
"cfg" that is either a dict or a CfgNode
- A string that can be parsed as valid YAML
"""
_assert_with_logging(
isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
"Expected first argument to be of type {} or {}, but it was {}".format(
_FILE_TYPES, str, type(cfg_file_obj_or_str)
),
)
if isinstance(cfg_file_obj_or_str, str):
return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
return cls._load_cfg_from_file(cfg_file_obj_or_str)
else:
raise NotImplementedError("Impossible to reach here (unless there's a bug)")

@classmethod
def _load_cfg_from_file(cls, file_obj):
"""Load a config from a YAML file or a Python source file."""
_, file_extension = os.path.splitext(file_obj.name)
if file_extension in _YAML_EXTS:
return cls._load_cfg_from_yaml_str(file_obj.read())
elif file_extension in _PY_EXTS:
return cls._load_cfg_py_source(file_obj.name)
else:
raise Exception(
"Attempt to load from an unsupported file type {}; "
"only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
)

@classmethod
def _load_cfg_from_yaml_str(cls, str_obj):
"""Load a config from a YAML string encoding."""
cfg_as_dict = yaml.safe_load(str_obj)
return cls(cfg_as_dict)

def _load_cfg_from_yaml_str(str_obj):
"""Load a config from a YAML string encoding."""
cfg_as_dict = yaml.safe_load(str_obj)
return CfgNode(cfg_as_dict)
@classmethod
def _load_cfg_py_source(cls, filename):
"""Load a config from a Python source file."""
module = _load_module_from_file("yacs.config.override", filename)
_assert_with_logging(
hasattr(module, "cfg"),
"Python module from file {} must have 'cfg' attr".format(filename),
)
VALID_ATTR_TYPES = {dict, CfgNode}
_assert_with_logging(
type(module.cfg) in VALID_ATTR_TYPES,
"Imported module 'cfg' attr must be in {} but is {} instead".format(
VALID_ATTR_TYPES, type(module.cfg)
),
)
return cls(module.cfg)

@classmethod
def _decode_cfg_value(cls, value):
"""
Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object.
def _load_cfg_py_source(filename):
"""Load a config from a Python source file."""
module = _load_module_from_file("yacs.config.override", filename)
_assert_with_logging(
hasattr(module, "cfg"),
"Python module from file {} must have 'cfg' attr".format(filename),
)
VALID_ATTR_TYPES = {dict, CfgNode}
_assert_with_logging(
type(module.cfg) in VALID_ATTR_TYPES,
"Imported module 'cfg' attr must be in {} but is {} instead".format(
VALID_ATTR_TYPES, type(module.cfg)
),
)
if type(module.cfg) is dict:
return CfgNode(module.cfg)
else:
return module.cfg
If the value is a dict, it will be interpreted as a new CfgNode.
If the value is a str, it will be evaluated as literals.
Otherwise it is returned as-is.
"""
# Configs parsed from raw yaml will contain dictionary keys that need to be
# converted to CfgNode objects
if isinstance(value, dict):
return cls(value)
# All remaining processing is only applied to strings
if not isinstance(value, str):
return value
# Try to interpret `value` as a:
# string, number, tuple, list, dict, boolean, or None
try:
value = literal_eval(value)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return value


load_cfg = CfgNode.load_cfg # keep this function in global scope for backward compatibility


def _valid_type(value, allow_cfg_node=False):
Expand All @@ -393,7 +450,7 @@ def _merge_a_into_b(a, b, root, key_list):
full_key = ".".join(key_list + [k])

v = copy.deepcopy(v_)
v = _decode_cfg_value(v)
v = b._decode_cfg_value(v)

if k in b:
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
Expand All @@ -416,39 +473,6 @@ def _merge_a_into_b(a, b, root, key_list):
raise KeyError("Non-existent config key: {}".format(full_key))


def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object.
"""
# Configs parsed from raw yaml will contain dictionary keys that need to be
# converted to CfgNode objects
if isinstance(v, dict):
return CfgNode(v)
# All remaining processing is only applied to strings
if not isinstance(v, str):
return v
# Try to interpret `v` as a:
# string, number, tuple, list, dict, boolean, or None
try:
v = literal_eval(v)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return v


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
"""Checks that `replacement`, which is intended to replace `original` is of
the right type. The type is correct if it matches exactly or is one of a few
Expand Down
66 changes: 56 additions & 10 deletions yacs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,30 @@
PY2 = False


def get_cfg():
cfg = CN()
class SubCN(CN):
pass


def get_cfg(cls=CN):
cfg = cls()

cfg.NUM_GPUS = 8

cfg.TRAIN = CN()
cfg.TRAIN = cls()
cfg.TRAIN.HYPERPARAMETER_1 = 0.1
cfg.TRAIN.SCALES = (2, 4, 8, 16)

cfg.MODEL = CN()
cfg.MODEL = cls()
cfg.MODEL.TYPE = "a_foo_model"

# Some extra stuff to test CfgNode.__str__
cfg.STR = CN()
cfg.STR = cls()
cfg.STR.KEY1 = 1
cfg.STR.KEY2 = 2
cfg.STR.FOO = CN()
cfg.STR.FOO = cls()
cfg.STR.FOO.KEY1 = 1
cfg.STR.FOO.KEY2 = 2
cfg.STR.FOO.BAR = CN()
cfg.STR.FOO.BAR = cls()
cfg.STR.FOO.BAR.KEY1 = 1
cfg.STR.FOO.BAR.KEY2 = 2

Expand All @@ -44,9 +48,9 @@ def get_cfg():
message="Please update your config fil config file.",
)

cfg.KWARGS = CN(new_allowed=True)
cfg.KWARGS = cls(new_allowed=True)
cfg.KWARGS.z = 0
cfg.KWARGS.Y = CN()
cfg.KWARGS.Y = cls()
cfg.KWARGS.Y.X = 1

return cfg
Expand Down Expand Up @@ -100,7 +104,7 @@ def test_merge_cfg_from_cfg(self):

# Test: merge from yaml
s = "dummy1"
cfg2 = yacs.config.load_cfg(cfg.dump())
cfg2 = CN.load_cfg(cfg.dump())
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s
Expand Down Expand Up @@ -295,6 +299,48 @@ def test_new_allowed_bad(self):
cfg.merge_from_file("example/config_new_allowed_bad.yaml")


class TestCfgNodeSubclass(unittest.TestCase):
def test_merge_cfg_from_file(self):
with tempfile.NamedTemporaryFile(mode="wt") as f:
cfg = get_cfg(SubCN)
f.write(cfg.dump())
f.flush()
s = cfg.MODEL.TYPE
cfg.MODEL.TYPE = "dummy"
assert cfg.MODEL.TYPE != s
cfg.merge_from_file(f.name)
assert cfg.MODEL.TYPE == s

def test_merge_cfg_from_list(self):
cfg = get_cfg(SubCN)
opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2]
assert len(cfg.TRAIN.SCALES) > 0
assert cfg.TRAIN.SCALES[0] != 100
assert cfg.MODEL.TYPE != "foobar"
assert cfg.NUM_GPUS != 2
cfg.merge_from_list(opts)
assert type(cfg.TRAIN.SCALES) is tuple
assert len(cfg.TRAIN.SCALES) == 1
assert cfg.TRAIN.SCALES[0] == 100
assert cfg.MODEL.TYPE == "foobar"
assert cfg.NUM_GPUS == 2

def test_merge_cfg_from_cfg(self):
cfg = get_cfg(SubCN)
cfg2 = get_cfg(SubCN)
s = "dummy0"
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s

# Test: merge from yaml
s = "dummy1"
cfg2 = SubCN.load_cfg(cfg.dump())
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s


if __name__ == "__main__":
logging.basicConfig()
yacs_logger = logging.getLogger("yacs.config")
Expand Down

0 comments on commit 3ba3ecf

Please sign in to comment.