From 3ba3ecfafc5b5ef05ab81190ae868740fda5efe5 Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Mon, 4 Mar 2019 12:00:37 -0800 Subject: [PATCH] Make several functions that construct CfgNode be classmethod, so that they are polymorphic and can be used to construct a subclass of CfgNode. --- yacs/config.py | 230 +++++++++++++++++++++++++++---------------------- yacs/tests.py | 66 +++++++++++--- 2 files changed, 183 insertions(+), 113 deletions(-) diff --git a/yacs/config.py b/yacs/config.py index 0089703..3e160b0 100644 --- a/yacs/config.py +++ b/yacs/config.py @@ -83,18 +83,7 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False): # Recursively convert nested dictionaries in init_dict into CfgNodes init_dict = {} if init_dict is None else init_dict key_list = [] if key_list is None else key_list - for k, v in init_dict.items(): - if type(v) is dict: - # Convert dict to CfgNode - init_dict[k] = CfgNode(v, key_list=key_list + [k]) - else: - # Check for valid leaf type or nested CfgNode - _assert_with_logging( - _valid_type(v, allow_cfg_node=True), - "Key {} with value {} is not a valid type; valid types: {}".format( - ".".join(key_list + [k]), type(v), _VALID_TYPES - ), - ) + init_dict = self._create_config_tree_from_dict(init_dict, key_list) super(CfgNode, self).__init__(init_dict) # Manage if the CfgNode is frozen or not self.__dict__[CfgNode.IMMUTABLE] = False @@ -119,6 +108,32 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False): # Allow new attributes after initialisation self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed + @classmethod + def _create_config_tree_from_dict(cls, dic, key_list): + """ + Create a configuration tree using the given dict. + Any dict-like objects inside dict will be treated as a new CfgNode. + + Args: + dic (dict): + key_list (list[str]): a list of names which index this CfgNode from the root. + Currently only used for logging purposes. + """ + dic = copy.deepcopy(dic) + for k, v in dic.items(): + if isinstance(v, dict): + # Convert dict to CfgNode + dic[k] = cls(v, key_list=key_list + [k]) + else: + # Check for valid leaf type or nested CfgNode + _assert_with_logging( + _valid_type(v, allow_cfg_node=False), + "Key {} with value {} is not a valid type; valid types: {}".format( + ".".join(key_list + [k]), type(v), _VALID_TYPES + ), + ) + return dic + def __getattr__(self, name): if name in self: return self[name] @@ -194,7 +209,7 @@ def convert_to_dict(cfg_node, key_list): def merge_from_file(self, cfg_filename): """Load a yaml config file and merge it this CfgNode.""" with open(cfg_filename, "r") as f: - cfg = load_cfg(f) + cfg = self.load_cfg(f) self.merge_from_other_cfg(cfg) def merge_from_other_cfg(self, cfg_other): @@ -226,7 +241,7 @@ def merge_from_list(self, cfg_list): d = d[subkey] subkey = key_list[-1] _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key)) - value = _decode_cfg_value(v) + value = self._decode_cfg_value(v) value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key) d[subkey] = value @@ -310,66 +325,108 @@ def raise_key_rename_error(self, full_key): def is_new_allowed(self): return self.__dict__[CfgNode.NEW_ALLOWED] - -def load_cfg(cfg_file_obj_or_str): - """Load a cfg. Supports loading from: - - A file object backed by a YAML file - - A file object backed by a Python source file that exports an attribute - "cfg" that is either a dict or a CfgNode - - A string that can be parsed as valid YAML - """ - _assert_with_logging( - isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)), - "Expected first argument to be of type {} or {}, but it was {}".format( - _FILE_TYPES, str, type(cfg_file_obj_or_str) - ), - ) - if isinstance(cfg_file_obj_or_str, str): - return _load_cfg_from_yaml_str(cfg_file_obj_or_str) - elif isinstance(cfg_file_obj_or_str, _FILE_TYPES): - return _load_cfg_from_file(cfg_file_obj_or_str) - else: - raise NotImplementedError("Impossible to reach here (unless there's a bug)") - - -def _load_cfg_from_file(file_obj): - """Load a config from a YAML file or a Python source file.""" - _, file_extension = os.path.splitext(file_obj.name) - if file_extension in _YAML_EXTS: - return _load_cfg_from_yaml_str(file_obj.read()) - elif file_extension in _PY_EXTS: - return _load_cfg_py_source(file_obj.name) - else: - raise Exception( - "Attempt to load from an unsupported file type {}; " - "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS)) + @classmethod + def load_cfg(cls, cfg_file_obj_or_str): + """ + Load a cfg. + Args: + cfg_file_obj_or_str (str or file): + Supports loading from: + - A file object backed by a YAML file + - A file object backed by a Python source file that exports an attribute + "cfg" that is either a dict or a CfgNode + - A string that can be parsed as valid YAML + """ + _assert_with_logging( + isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)), + "Expected first argument to be of type {} or {}, but it was {}".format( + _FILE_TYPES, str, type(cfg_file_obj_or_str) + ), ) + if isinstance(cfg_file_obj_or_str, str): + return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str) + elif isinstance(cfg_file_obj_or_str, _FILE_TYPES): + return cls._load_cfg_from_file(cfg_file_obj_or_str) + else: + raise NotImplementedError("Impossible to reach here (unless there's a bug)") + + @classmethod + def _load_cfg_from_file(cls, file_obj): + """Load a config from a YAML file or a Python source file.""" + _, file_extension = os.path.splitext(file_obj.name) + if file_extension in _YAML_EXTS: + return cls._load_cfg_from_yaml_str(file_obj.read()) + elif file_extension in _PY_EXTS: + return cls._load_cfg_py_source(file_obj.name) + else: + raise Exception( + "Attempt to load from an unsupported file type {}; " + "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS)) + ) + @classmethod + def _load_cfg_from_yaml_str(cls, str_obj): + """Load a config from a YAML string encoding.""" + cfg_as_dict = yaml.safe_load(str_obj) + return cls(cfg_as_dict) -def _load_cfg_from_yaml_str(str_obj): - """Load a config from a YAML string encoding.""" - cfg_as_dict = yaml.safe_load(str_obj) - return CfgNode(cfg_as_dict) + @classmethod + def _load_cfg_py_source(cls, filename): + """Load a config from a Python source file.""" + module = _load_module_from_file("yacs.config.override", filename) + _assert_with_logging( + hasattr(module, "cfg"), + "Python module from file {} must have 'cfg' attr".format(filename), + ) + VALID_ATTR_TYPES = {dict, CfgNode} + _assert_with_logging( + type(module.cfg) in VALID_ATTR_TYPES, + "Imported module 'cfg' attr must be in {} but is {} instead".format( + VALID_ATTR_TYPES, type(module.cfg) + ), + ) + return cls(module.cfg) + @classmethod + def _decode_cfg_value(cls, value): + """ + Decodes a raw config value (e.g., from a yaml config files or command + line argument) into a Python object. -def _load_cfg_py_source(filename): - """Load a config from a Python source file.""" - module = _load_module_from_file("yacs.config.override", filename) - _assert_with_logging( - hasattr(module, "cfg"), - "Python module from file {} must have 'cfg' attr".format(filename), - ) - VALID_ATTR_TYPES = {dict, CfgNode} - _assert_with_logging( - type(module.cfg) in VALID_ATTR_TYPES, - "Imported module 'cfg' attr must be in {} but is {} instead".format( - VALID_ATTR_TYPES, type(module.cfg) - ), - ) - if type(module.cfg) is dict: - return CfgNode(module.cfg) - else: - return module.cfg + If the value is a dict, it will be interpreted as a new CfgNode. + If the value is a str, it will be evaluated as literals. + Otherwise it is returned as-is. + """ + # Configs parsed from raw yaml will contain dictionary keys that need to be + # converted to CfgNode objects + if isinstance(value, dict): + return cls(value) + # All remaining processing is only applied to strings + if not isinstance(value, str): + return value + # Try to interpret `value` as a: + # string, number, tuple, list, dict, boolean, or None + try: + value = literal_eval(value) + # The following two excepts allow v to pass through when it represents a + # string. + # + # Longer explanation: + # The type of v is always a string (before calling literal_eval), but + # sometimes it *represents* a string and other times a data structure, like + # a list. In the case that v represents a string, what we got back from the + # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is + # ok with '"foo"', but will raise a ValueError if given 'foo'. In other + # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval + # will raise a SyntaxError. + except ValueError: + pass + except SyntaxError: + pass + return value + + +load_cfg = CfgNode.load_cfg # keep this function in global scope for backward compatibility def _valid_type(value, allow_cfg_node=False): @@ -393,7 +450,7 @@ def _merge_a_into_b(a, b, root, key_list): full_key = ".".join(key_list + [k]) v = copy.deepcopy(v_) - v = _decode_cfg_value(v) + v = b._decode_cfg_value(v) if k in b: v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) @@ -416,39 +473,6 @@ def _merge_a_into_b(a, b, root, key_list): raise KeyError("Non-existent config key: {}".format(full_key)) -def _decode_cfg_value(v): - """Decodes a raw config value (e.g., from a yaml config files or command - line argument) into a Python object. - """ - # Configs parsed from raw yaml will contain dictionary keys that need to be - # converted to CfgNode objects - if isinstance(v, dict): - return CfgNode(v) - # All remaining processing is only applied to strings - if not isinstance(v, str): - return v - # Try to interpret `v` as a: - # string, number, tuple, list, dict, boolean, or None - try: - v = literal_eval(v) - # The following two excepts allow v to pass through when it represents a - # string. - # - # Longer explanation: - # The type of v is always a string (before calling literal_eval), but - # sometimes it *represents* a string and other times a data structure, like - # a list. In the case that v represents a string, what we got back from the - # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is - # ok with '"foo"', but will raise a ValueError if given 'foo'. In other - # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval - # will raise a SyntaxError. - except ValueError: - pass - except SyntaxError: - pass - return v - - def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): """Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if it matches exactly or is one of a few diff --git a/yacs/tests.py b/yacs/tests.py index f6f145f..8941321 100644 --- a/yacs/tests.py +++ b/yacs/tests.py @@ -12,26 +12,30 @@ PY2 = False -def get_cfg(): - cfg = CN() +class SubCN(CN): + pass + + +def get_cfg(cls=CN): + cfg = cls() cfg.NUM_GPUS = 8 - cfg.TRAIN = CN() + cfg.TRAIN = cls() cfg.TRAIN.HYPERPARAMETER_1 = 0.1 cfg.TRAIN.SCALES = (2, 4, 8, 16) - cfg.MODEL = CN() + cfg.MODEL = cls() cfg.MODEL.TYPE = "a_foo_model" # Some extra stuff to test CfgNode.__str__ - cfg.STR = CN() + cfg.STR = cls() cfg.STR.KEY1 = 1 cfg.STR.KEY2 = 2 - cfg.STR.FOO = CN() + cfg.STR.FOO = cls() cfg.STR.FOO.KEY1 = 1 cfg.STR.FOO.KEY2 = 2 - cfg.STR.FOO.BAR = CN() + cfg.STR.FOO.BAR = cls() cfg.STR.FOO.BAR.KEY1 = 1 cfg.STR.FOO.BAR.KEY2 = 2 @@ -44,9 +48,9 @@ def get_cfg(): message="Please update your config fil config file.", ) - cfg.KWARGS = CN(new_allowed=True) + cfg.KWARGS = cls(new_allowed=True) cfg.KWARGS.z = 0 - cfg.KWARGS.Y = CN() + cfg.KWARGS.Y = cls() cfg.KWARGS.Y.X = 1 return cfg @@ -100,7 +104,7 @@ def test_merge_cfg_from_cfg(self): # Test: merge from yaml s = "dummy1" - cfg2 = yacs.config.load_cfg(cfg.dump()) + cfg2 = CN.load_cfg(cfg.dump()) cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s @@ -295,6 +299,48 @@ def test_new_allowed_bad(self): cfg.merge_from_file("example/config_new_allowed_bad.yaml") +class TestCfgNodeSubclass(unittest.TestCase): + def test_merge_cfg_from_file(self): + with tempfile.NamedTemporaryFile(mode="wt") as f: + cfg = get_cfg(SubCN) + f.write(cfg.dump()) + f.flush() + s = cfg.MODEL.TYPE + cfg.MODEL.TYPE = "dummy" + assert cfg.MODEL.TYPE != s + cfg.merge_from_file(f.name) + assert cfg.MODEL.TYPE == s + + def test_merge_cfg_from_list(self): + cfg = get_cfg(SubCN) + opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2] + assert len(cfg.TRAIN.SCALES) > 0 + assert cfg.TRAIN.SCALES[0] != 100 + assert cfg.MODEL.TYPE != "foobar" + assert cfg.NUM_GPUS != 2 + cfg.merge_from_list(opts) + assert type(cfg.TRAIN.SCALES) is tuple + assert len(cfg.TRAIN.SCALES) == 1 + assert cfg.TRAIN.SCALES[0] == 100 + assert cfg.MODEL.TYPE == "foobar" + assert cfg.NUM_GPUS == 2 + + def test_merge_cfg_from_cfg(self): + cfg = get_cfg(SubCN) + cfg2 = get_cfg(SubCN) + s = "dummy0" + cfg2.MODEL.TYPE = s + cfg.merge_from_other_cfg(cfg2) + assert cfg.MODEL.TYPE == s + + # Test: merge from yaml + s = "dummy1" + cfg2 = SubCN.load_cfg(cfg.dump()) + cfg2.MODEL.TYPE = s + cfg.merge_from_other_cfg(cfg2) + assert cfg.MODEL.TYPE == s + + if __name__ == "__main__": logging.basicConfig() yacs_logger = logging.getLogger("yacs.config")