From 6adf4d44b819862f3ede535ff89da6a437409be4 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 6 Aug 2019 14:46:03 +0200 Subject: [PATCH 1/2] added support for required parameters and help descriptions --- yacs/__init__.py | 23 ++++++++++ yacs/config.py | 113 +++++++++++++++++++++++++++++++---------------- yacs/params.py | 31 +++++++++++++ yacs/tests.py | 37 +++++++++++++++- 4 files changed, 164 insertions(+), 40 deletions(-) create mode 100644 yacs/params.py diff --git a/yacs/__init__.py b/yacs/__init__.py index e69de29..f54add1 100644 --- a/yacs/__init__.py +++ b/yacs/__init__.py @@ -0,0 +1,23 @@ +import io +import sys + +# Flag for py2 and py3 compatibility to use when separate code paths are necessary +# When _PY2 is False, we assume Python 3 is in use +_PY2 = sys.version_info.major == 2 + +# Filename extensions for loading configs from files +_YAML_EXTS = {"", ".yaml", ".yml"} +_PY_EXTS = {".py"} + +# py2 and py3 compatibility for checking file object type +# We simply use this to infer py2 vs py3 +if _PY2: + _FILE_TYPES = (file, io.IOBase) +else: + _FILE_TYPES = (io.IOBase,) + +# CfgNodes can only contain a limited set of valid types +_VALID_TYPES = {tuple, list, str, int, float, bool} +# py2 allow for str and unicode +if _PY2: + _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 diff --git a/yacs/config.py b/yacs/config.py index e39898d..7a78c3e 100644 --- a/yacs/config.py +++ b/yacs/config.py @@ -21,34 +21,15 @@ """ import copy -import io import logging import os -import sys from ast import literal_eval +from collections import defaultdict import yaml -# Flag for py2 and py3 compatibility to use when separate code paths are necessary -# When _PY2 is False, we assume Python 3 is in use -_PY2 = sys.version_info.major == 2 - -# Filename extensions for loading configs from files -_YAML_EXTS = {"", ".yaml", ".yml"} -_PY_EXTS = {".py"} - -# py2 and py3 compatibility for checking file object type -# We simply use this to infer py2 vs py3 -if _PY2: - _FILE_TYPES = (file, io.IOBase) -else: - _FILE_TYPES = (io.IOBase,) - -# CfgNodes can only contain a limited set of valid types -_VALID_TYPES = {tuple, list, str, int, float, bool} -# py2 allow for str and unicode -if _PY2: - _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 +from yacs import _PY2, _PY_EXTS, _VALID_TYPES, _FILE_TYPES, _YAML_EXTS +from yacs.params import Parameter, Required # Utilities for importing modules from file paths if _PY2: @@ -57,6 +38,7 @@ else: import importlib.util + logger = logging.getLogger(__name__) @@ -70,6 +52,7 @@ class CfgNode(dict): DEPRECATED_KEYS = "__deprecated_keys__" RENAMED_KEYS = "__renamed_keys__" NEW_ALLOWED = "__new_allowed__" + HELP = "__help__" def __init__(self, init_dict=None, key_list=None, new_allowed=False): """ @@ -83,7 +66,7 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False): # Recursively convert nested dictionaries in init_dict into CfgNodes init_dict = {} if init_dict is None else init_dict key_list = [] if key_list is None else key_list - init_dict = self._create_config_tree_from_dict(init_dict, key_list) + init_dict, help_dict = self._create_config_tree_from_dict(init_dict, key_list) super(CfgNode, self).__init__(init_dict) # Manage if the CfgNode is frozen or not self.__dict__[CfgNode.IMMUTABLE] = False @@ -107,6 +90,7 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False): # Allow new attributes after initialisation self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed + self.__dict__[CfgNode.HELP] = help_dict @classmethod def _create_config_tree_from_dict(cls, dic, key_list): @@ -120,19 +104,25 @@ def _create_config_tree_from_dict(cls, dic, key_list): Currently only used for logging purposes. """ dic = copy.deepcopy(dic) + help_dict = defaultdict(str) + for k, v in dic.items(): + # Check for valid leaf type or nested CfgNode if isinstance(v, dict): # Convert dict to CfgNode dic[k] = cls(v, key_list=key_list + [k]) else: - # Check for valid leaf type or nested CfgNode + if isinstance(v, Parameter): + v = v.value + help_dict[k] = v.help + _assert_with_logging( _valid_type(v, allow_cfg_node=False), "Key {} with value {} is not a valid type; valid types: {}".format( ".".join(key_list + [k]), type(v), _VALID_TYPES ), ) - return dic + return dic, help_dict def __getattr__(self, name): if name in self: @@ -152,14 +142,38 @@ def __setattr__(self, name, value): name not in self.__dict__, "Invalid attempt to modify internal CfgNode state: {}".format(name), ) - _assert_with_logging( - _valid_type(value, allow_cfg_node=True), - "Invalid type {} for key {}; valid types = {}".format( - type(value), name, _VALID_TYPES - ), - ) - self[name] = value + if isinstance(self.get(name), Required): + if isinstance(value, Parameter): + value_type = value.type + else: + value_type = type(value) + + required_type = self.get(name).type + _assert_with_logging( + value_type == required_type, + "Invalid type {} for key {}; required type = {}".format( + value_type, name, required_type + ) + ) + + if isinstance(value, Parameter): + if not value.required: + self[name] = value.value + else: + self[name] = value + + self.__dict__[CfgNode.HELP][name] = value.description + + else: + _assert_with_logging( + _valid_type(value, allow_cfg_node=True), + "Invalid type {} for key {}; valid types = {}".format( + type(value), name, _VALID_TYPES + ), + ) + + self[name] = value def __str__(self): def _indent(s_, num_spaces): @@ -177,6 +191,9 @@ def _indent(s_, num_spaces): for k, v in sorted(self.items()): seperator = "\n" if isinstance(v, CfgNode) else " " attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + description = self.get_description(k) + if description: + attr_str += '\t\t[{}]'.format(description) attr_str = _indent(attr_str, 2) s.append(attr_str) r += "\n".join(s) @@ -190,12 +207,20 @@ def dump(self, **kwargs): def convert_to_dict(cfg_node, key_list): if not isinstance(cfg_node, CfgNode): - _assert_with_logging( - _valid_type(cfg_node), - "Key {} with value {} is not a valid type; valid types: {}".format( - ".".join(key_list), type(cfg_node), _VALID_TYPES - ), - ) + if isinstance(cfg_node, Parameter): + _assert_with_logging( + not cfg_node.required, + "Key {} is required to be overloaded a parameter of type {}".format( + ".".join(key_list), cfg_node.type.__name__ + ) + ) + else: + _assert_with_logging( + _valid_type(cfg_node), + "Key {} with value {} is not a valid type; valid types: {}".format( + ".".join(key_list), type(cfg_node), _VALID_TYPES + ), + ) return cfg_node else: cfg_dict = dict(cfg_node) @@ -425,6 +450,9 @@ def _decode_cfg_value(cls, value): pass return value + def get_description(self, key): + return self.__dict__[CfgNode.HELP][key] + load_cfg = ( CfgNode.load_cfg @@ -432,6 +460,7 @@ def _decode_cfg_value(cls, value): def _valid_type(value, allow_cfg_node=False): + value = value.value if isinstance(value, Parameter) else value return (type(value) in _VALID_TYPES) or ( allow_cfg_node and isinstance(value, CfgNode) ) @@ -489,6 +518,14 @@ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): if replacement_type == original_type: return replacement + if isinstance(original, Parameter): + if isinstance(replacement, Parameter): + if replacement.type == original.type: + return replacement + else: + if replacement_type == original.type: + return replacement + # Cast replacement from from_type to to_type if the replacement and original # types match from_type and to_type def conditional_cast(from_type, to_type): diff --git a/yacs/params.py b/yacs/params.py new file mode 100644 index 0000000..c744284 --- /dev/null +++ b/yacs/params.py @@ -0,0 +1,31 @@ +from yacs import _VALID_TYPES + + +class Parameter: + def __init__(self, value, value_type=None, description=''): + self.value = value + self.description = description + if value_type is not None: + assert value_type in _VALID_TYPES, "" + self.type = value_type + else: + self.type = type(value) + self.required = False + + def __str__(self): + return str(self.value) + + def __repr__(self): + return str(self.value) + + +class Required(Parameter): + def __init__(self, value_type, description=''): + super(Required, self).__init__(None, value_type=value_type, description=description) + self.required = True + + def __repr__(self): + return 'required({})'.format(self.type.__name__) + + def __str__(self): + return 'required({})'.format(self.type.__name__) diff --git a/yacs/tests.py b/yacs/tests.py index 6748214..d55ca74 100644 --- a/yacs/tests.py +++ b/yacs/tests.py @@ -4,6 +4,7 @@ import yacs.config from yacs.config import CfgNode as CN +from yacs.params import Parameter, Required try: _ignore = unicode # noqa: F821 @@ -21,8 +22,11 @@ def get_cfg(cls=CN): cfg.NUM_GPUS = 8 + # required keys + cfg.REQUIRED_FLOAT = Required(float, description="a required float parameter") + cfg.TRAIN = cls() - cfg.TRAIN.HYPERPARAMETER_1 = 0.1 + cfg.TRAIN.HYPERPARAMETER_1 = Parameter(0.1, description='hyperparameter 1') cfg.TRAIN.SCALES = (2, 4, 8, 16) cfg.MODEL = cls() @@ -96,6 +100,8 @@ def test_copy_cfg(self): def test_merge_cfg_from_cfg(self): # Test: merge from clone cfg = get_cfg() + cfg.REQUIRED_FLOAT = 1.0 + s = "dummy0" cfg2 = cfg.clone() cfg2.MODEL.TYPE = s @@ -133,6 +139,13 @@ def test_merge_cfg_from_cfg(self): assert type(cfg.TRAIN.SCALES) is tuple assert cfg.TRAIN.SCALES[0] == 1 + # Test: merge with required key + cfg1 = get_cfg() + cfg2 = CN() + cfg2.REQUIRED_FLOAT = 1.0 + cfg1.merge_from_other_cfg(cfg2) + assert cfg1.REQUIRED_FLOAT == 1.0 + # Test str (bytes) <-> unicode conversion for py2 if PY2: cfg.A_UNICODE_KEY = u"foo" @@ -152,6 +165,8 @@ def test_merge_cfg_from_cfg(self): def test_merge_cfg_from_file(self): with tempfile.NamedTemporaryFile(mode="wt") as f: cfg = get_cfg() + cfg.REQUIRED_FLOAT = 1.0 + f.write(cfg.dump()) f.flush() s = cfg.MODEL.TYPE @@ -205,6 +220,8 @@ def test_deprecated_key_from_file(self): # You should see logger messages like: # "Deprecated config key (ignoring): MODEL.DILATION" cfg = get_cfg() + cfg.REQUIRED_FLOAT = 1.0 + with tempfile.NamedTemporaryFile("wt") as f: cfg2 = cfg.clone() cfg2.MODEL.DILATION = 2 @@ -226,6 +243,8 @@ def test_renamed_key_from_list(self): def test_renamed_key_from_file(self): cfg = get_cfg() + cfg.REQUIRED_FLOAT = 1.0 + with tempfile.NamedTemporaryFile("wt") as f: cfg2 = cfg.clone() cfg2.EXAMPLE = CN() @@ -240,6 +259,8 @@ def test_renamed_key_from_file(self): def test_load_cfg_from_file(self): cfg = get_cfg() + cfg.REQUIRED_FLOAT = 1.0 + with tempfile.NamedTemporaryFile("wt") as f: f.write(cfg.dump()) f.flush() @@ -270,6 +291,7 @@ def test__str__(self): MODEL: TYPE: a_foo_model NUM_GPUS: 8 +REQUIRED_FLOAT: required(float) [a required float parameter] STR: FOO: BAR: @@ -280,7 +302,7 @@ def test__str__(self): KEY1: 1 KEY2: 2 TRAIN: - HYPERPARAMETER_1: 0.1 + HYPERPARAMETER_1: 0.1 [hyperparameter 1] SCALES: (2, 4, 8, 16) """.strip() cfg = get_cfg() @@ -298,11 +320,18 @@ def test_new_allowed_bad(self): with self.assertRaises(KeyError): cfg.merge_from_file("example/config_new_allowed_bad.yaml") + def test_invalid_overload(self): + cfg = get_cfg() + with self.assertRaises(AssertionError): + cfg.REQUIRED_FLOAT = 1 + class TestCfgNodeSubclass(unittest.TestCase): def test_merge_cfg_from_file(self): with tempfile.NamedTemporaryFile(mode="wt") as f: cfg = get_cfg(SubCN) + cfg.REQUIRED_FLOAT = 1.0 + f.write(cfg.dump()) f.flush() s = cfg.MODEL.TYPE @@ -327,7 +356,11 @@ def test_merge_cfg_from_list(self): def test_merge_cfg_from_cfg(self): cfg = get_cfg(SubCN) + cfg.REQUIRED_FLOAT = 1.0 + cfg2 = get_cfg(SubCN) + cfg2.REQUIRED_FLOAT = 1.0 + s = "dummy0" cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) From d30dd78567bb03363854fa8c96167d8046c61dd4 Mon Sep 17 00:00:00 2001 From: miguel Date: Wed, 7 Aug 2019 10:12:32 +0200 Subject: [PATCH 2/2] add check for required parameters when freeze is called --- yacs/config.py | 13 +++++++++++++ yacs/tests.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/yacs/config.py b/yacs/config.py index 7a78c3e..d80e62a 100644 --- a/yacs/config.py +++ b/yacs/config.py @@ -272,6 +272,10 @@ def merge_from_list(self, cfg_list): def freeze(self): """Make this CfgNode and all of its children immutable.""" + _assert_with_logging( + not self._has_required_params(), + "Config has unset required parameters." + ) self._immutable(True) def defrost(self): @@ -295,6 +299,15 @@ def _immutable(self, is_immutable): if isinstance(v, CfgNode): v._immutable(is_immutable) + def _has_required_params(self): + for k, v in self.items(): + if isinstance(v, CfgNode): + if v._has_required_params(): + return True + if isinstance(v, Required): + return True + return False + def clone(self): """Recursively copy this CfgNode.""" return copy.deepcopy(self) diff --git a/yacs/tests.py b/yacs/tests.py index d55ca74..e7553e8 100644 --- a/yacs/tests.py +++ b/yacs/tests.py @@ -88,6 +88,19 @@ def test_immutability(self): a.level1.bar = 1 assert a.level1.level2.foo == 0 + def test_freeze_With_required_params(self): + cfg1 = CN() + cfg1.foo = Required(int) + with self.assertRaises(AssertionError): + cfg1.freeze() + + cfg2 = CN() + cfg2.foo = CN() + cfg2.foo.bar = Required(str) + + with self.assertRaises(AssertionError): + cfg2.freeze() + class TestCfg(unittest.TestCase): def test_copy_cfg(self):