Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for required parameters and help strings (#27) #28

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions yacs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import io
import sys

# Flag for py2 and py3 compatibility to use when separate code paths are necessary
# When _PY2 is False, we assume Python 3 is in use
_PY2 = sys.version_info.major == 2

# Filename extensions for loading configs from files
_YAML_EXTS = {"", ".yaml", ".yml"}
_PY_EXTS = {".py"}

# py2 and py3 compatibility for checking file object type
# We simply use this to infer py2 vs py3
if _PY2:
_FILE_TYPES = (file, io.IOBase)
else:
_FILE_TYPES = (io.IOBase,)

# CfgNodes can only contain a limited set of valid types
_VALID_TYPES = {tuple, list, str, int, float, bool}
# py2 allow for str and unicode
if _PY2:
_VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
126 changes: 88 additions & 38 deletions yacs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,15 @@
"""

import copy
import io
import logging
import os
import sys
from ast import literal_eval
from collections import defaultdict

import yaml

# Flag for py2 and py3 compatibility to use when separate code paths are necessary
# When _PY2 is False, we assume Python 3 is in use
_PY2 = sys.version_info.major == 2

# Filename extensions for loading configs from files
_YAML_EXTS = {"", ".yaml", ".yml"}
_PY_EXTS = {".py"}

# py2 and py3 compatibility for checking file object type
# We simply use this to infer py2 vs py3
if _PY2:
_FILE_TYPES = (file, io.IOBase)
else:
_FILE_TYPES = (io.IOBase,)

# CfgNodes can only contain a limited set of valid types
_VALID_TYPES = {tuple, list, str, int, float, bool}
# py2 allow for str and unicode
if _PY2:
_VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
from yacs import _PY2, _PY_EXTS, _VALID_TYPES, _FILE_TYPES, _YAML_EXTS
from yacs.params import Parameter, Required

# Utilities for importing modules from file paths
if _PY2:
Expand All @@ -57,6 +38,7 @@
else:
import importlib.util


logger = logging.getLogger(__name__)


Expand All @@ -70,6 +52,7 @@ class CfgNode(dict):
DEPRECATED_KEYS = "__deprecated_keys__"
RENAMED_KEYS = "__renamed_keys__"
NEW_ALLOWED = "__new_allowed__"
HELP = "__help__"

def __init__(self, init_dict=None, key_list=None, new_allowed=False):
"""
Expand All @@ -83,7 +66,7 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
# Recursively convert nested dictionaries in init_dict into CfgNodes
init_dict = {} if init_dict is None else init_dict
key_list = [] if key_list is None else key_list
init_dict = self._create_config_tree_from_dict(init_dict, key_list)
init_dict, help_dict = self._create_config_tree_from_dict(init_dict, key_list)
super(CfgNode, self).__init__(init_dict)
# Manage if the CfgNode is frozen or not
self.__dict__[CfgNode.IMMUTABLE] = False
Expand All @@ -107,6 +90,7 @@ def __init__(self, init_dict=None, key_list=None, new_allowed=False):

# Allow new attributes after initialisation
self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed
self.__dict__[CfgNode.HELP] = help_dict

@classmethod
def _create_config_tree_from_dict(cls, dic, key_list):
Expand All @@ -120,19 +104,25 @@ def _create_config_tree_from_dict(cls, dic, key_list):
Currently only used for logging purposes.
"""
dic = copy.deepcopy(dic)
help_dict = defaultdict(str)

for k, v in dic.items():
# Check for valid leaf type or nested CfgNode
if isinstance(v, dict):
# Convert dict to CfgNode
dic[k] = cls(v, key_list=key_list + [k])
else:
# Check for valid leaf type or nested CfgNode
if isinstance(v, Parameter):
v = v.value
help_dict[k] = v.help

_assert_with_logging(
_valid_type(v, allow_cfg_node=False),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list + [k]), type(v), _VALID_TYPES
),
)
return dic
return dic, help_dict

def __getattr__(self, name):
if name in self:
Expand All @@ -152,14 +142,38 @@ def __setattr__(self, name, value):
name not in self.__dict__,
"Invalid attempt to modify internal CfgNode state: {}".format(name),
)
_assert_with_logging(
_valid_type(value, allow_cfg_node=True),
"Invalid type {} for key {}; valid types = {}".format(
type(value), name, _VALID_TYPES
),
)

self[name] = value
if isinstance(self.get(name), Required):
if isinstance(value, Parameter):
value_type = value.type
else:
value_type = type(value)

required_type = self.get(name).type
_assert_with_logging(
value_type == required_type,
"Invalid type {} for key {}; required type = {}".format(
value_type, name, required_type
)
)

if isinstance(value, Parameter):
if not value.required:
self[name] = value.value
else:
self[name] = value

self.__dict__[CfgNode.HELP][name] = value.description

else:
_assert_with_logging(
_valid_type(value, allow_cfg_node=True),
"Invalid type {} for key {}; valid types = {}".format(
type(value), name, _VALID_TYPES
),
)

self[name] = value

def __str__(self):
def _indent(s_, num_spaces):
Expand All @@ -177,6 +191,9 @@ def _indent(s_, num_spaces):
for k, v in sorted(self.items()):
seperator = "\n" if isinstance(v, CfgNode) else " "
attr_str = "{}:{}{}".format(str(k), seperator, str(v))
description = self.get_description(k)
if description:
attr_str += '\t\t[{}]'.format(description)
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += "\n".join(s)
Expand All @@ -190,12 +207,20 @@ def dump(self, **kwargs):

def convert_to_dict(cfg_node, key_list):
if not isinstance(cfg_node, CfgNode):
_assert_with_logging(
_valid_type(cfg_node),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list), type(cfg_node), _VALID_TYPES
),
)
if isinstance(cfg_node, Parameter):
_assert_with_logging(
not cfg_node.required,
"Key {} is required to be overloaded a parameter of type {}".format(
".".join(key_list), cfg_node.type.__name__
)
)
else:
_assert_with_logging(
_valid_type(cfg_node),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list), type(cfg_node), _VALID_TYPES
),
)
return cfg_node
else:
cfg_dict = dict(cfg_node)
Expand Down Expand Up @@ -247,6 +272,10 @@ def merge_from_list(self, cfg_list):

def freeze(self):
"""Make this CfgNode and all of its children immutable."""
_assert_with_logging(
not self._has_required_params(),
"Config has unset required parameters."
)
self._immutable(True)

def defrost(self):
Expand All @@ -270,6 +299,15 @@ def _immutable(self, is_immutable):
if isinstance(v, CfgNode):
v._immutable(is_immutable)

def _has_required_params(self):
for k, v in self.items():
if isinstance(v, CfgNode):
if v._has_required_params():
return True
if isinstance(v, Required):
return True
return False

def clone(self):
"""Recursively copy this CfgNode."""
return copy.deepcopy(self)
Expand Down Expand Up @@ -425,13 +463,17 @@ def _decode_cfg_value(cls, value):
pass
return value

def get_description(self, key):
return self.__dict__[CfgNode.HELP][key]


load_cfg = (
CfgNode.load_cfg
) # keep this function in global scope for backward compatibility


def _valid_type(value, allow_cfg_node=False):
value = value.value if isinstance(value, Parameter) else value
return (type(value) in _VALID_TYPES) or (
allow_cfg_node and isinstance(value, CfgNode)
)
Expand Down Expand Up @@ -489,6 +531,14 @@ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
if replacement_type == original_type:
return replacement

if isinstance(original, Parameter):
if isinstance(replacement, Parameter):
if replacement.type == original.type:
return replacement
else:
if replacement_type == original.type:
return replacement

# Cast replacement from from_type to to_type if the replacement and original
# types match from_type and to_type
def conditional_cast(from_type, to_type):
Expand Down
31 changes: 31 additions & 0 deletions yacs/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from yacs import _VALID_TYPES


class Parameter:
def __init__(self, value, value_type=None, description=''):
self.value = value
self.description = description
if value_type is not None:
assert value_type in _VALID_TYPES, ""
self.type = value_type
else:
self.type = type(value)
self.required = False

def __str__(self):
return str(self.value)

def __repr__(self):
return str(self.value)


class Required(Parameter):
def __init__(self, value_type, description=''):
super(Required, self).__init__(None, value_type=value_type, description=description)
self.required = True

def __repr__(self):
return 'required({})'.format(self.type.__name__)

def __str__(self):
return 'required({})'.format(self.type.__name__)
Loading