diff --git a/CHANGELOG.md b/CHANGELOG.md index 3469fc47..f78680d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ +# v2.9.2 + +- Mixins can be specified, extended and overwritten via `Experiment().component(mixins=[])` + # v2.9.1 - New nested version syntax `~nested:version` diff --git a/src/machinable/config/interface.py b/src/machinable/config/interface.py index 5200b439..8ebd9b2f 100644 --- a/src/machinable/config/interface.py +++ b/src/machinable/config/interface.py @@ -21,11 +21,23 @@ def _collect_updates(version): if not isinstance(arguments, tuple): arguments = (arguments,) - collection.extend(arguments) + collection.extend([a["components"] for a in arguments if "components" in a]) return collection +def _mark_mixin_as_override(mixin): + if isinstance(mixin, str): + mixin = ("!" + mixin) if not mixin.startswith("!") else mixin + if isinstance(mixin, dict): + mixin["name"] = ( + ("!" + mixin["name"]) + if not mixin["name"].startswith("!") + else mixin["name"] + ) + return mixin + + def mapped_config(config): if isinstance(config, list): return [mapped_config(c) for c in config] @@ -42,6 +54,38 @@ def __init__(self, parsed_config, version=None, default_class=None): self.default_class = default_class self.schema_validation = get_settings()["schema_validation"] + @staticmethod + def call_with_context(function, component, components=None, resources=None): + signature = inspect.signature(function) + payload = OrderedDict() + + for index, (key, parameter) in enumerate(signature.parameters.items()): + if parameter.kind is not parameter.POSITIONAL_OR_KEYWORD: + # disallow *args and **kwargs + raise TypeError( + f"Method only allows simple positional or keyword arguments," + f"for example lambda(node, components, resources)" + ) + + if key == "component": + payload["component"] = mapped_config(component) + elif key == "config": + payload["config"] = config_map(component["args"]) + elif key == "flags": + payload["flags"] = config_map(component["flags"]) + elif components is not None and key == "components": + payload["components"] = mapped_config(components) + elif resources is not None and key == "resources": + payload["resources"] = resources + else: + raise ValueError( + f"Unrecognized argument: '{key}'. " + f"Experiment directory takes the following arguments: " + f"'node', 'components' and 'resources'" + ) + + return function(**payload) + def _get_version(self, name, config): # from yaml file if name.endswith(".yaml") or name.endswith(".json"): @@ -58,10 +102,6 @@ def _get_version(self, name, config): f"Did you register it under 'mixins'?\n" ) - # from yaml string - if not name.startswith("~"): - return yaml.load(name, Loader=yaml.FullLoader) - # from local version version = {} path = name[1:].split(":") @@ -85,39 +125,7 @@ def _get_version(self, name, config): f"Available versions: {[f for f in config.keys() if f.startswith('~')]}\n" ) - @staticmethod - def call_with_context(function, component, components=None, resources=None): - signature = inspect.signature(function) - payload = OrderedDict() - - for index, (key, parameter) in enumerate(signature.parameters.items()): - if parameter.kind is not parameter.POSITIONAL_OR_KEYWORD: - # disallow *args and **kwargs - raise TypeError( - f"Method only allows simple positional or keyword arguments," - f"for example lambda(node, components, resources)" - ) - - if key == "component": - payload["component"] = mapped_config(component) - elif key == "config": - payload["config"] = config_map(component["args"]) - elif key == "flags": - payload["flags"] = config_map(component["flags"]) - elif components is not None and key == "components": - payload["components"] = mapped_config(components) - elif resources is not None and key == "resources": - payload["resources"] = resources - else: - raise ValueError( - f"Unrecognized argument: '{key}'. " - f"Experiment directory takes the following arguments: " - f"'node', 'components' and 'resources'" - ) - - return function(**payload) - - def get_component(self, name, version=None, flags=None): + def get_component(self, name, version=None, mixins=None, flags=None): if name is None: return None @@ -170,9 +178,26 @@ def get_component(self, name, version=None, flags=None): # un-alias origin = self.data["components"]["@"][name] - # mixins - for i, mixin in enumerate(parse_mixins(config["args"].get("_mixins_"))): - mixin_info = {"name": mixin["name"]} + # parse mixins + default_mixins = config["args"].pop("_mixins_", None) + if mixins is None: + mixins = default_mixins + else: + if default_mixins is None: + default_mixins = [] + if not isinstance(default_mixins, (list, tuple)): + default_mixins = [default_mixins] + merged = [] + for m in mixins: + if m == "^": + merged.extend(default_mixins) + else: + merged.append(_mark_mixin_as_override(m)) + mixins = merged + + mixin_specs = [] + for mixin in parse_mixins(mixins): + mixin_spec = {"name": mixin["name"]} if mixin["name"].startswith("+."): raise AttributeError( @@ -194,7 +219,7 @@ def get_component(self, name, version=None, flags=None): ) # un-alias - mixin_info["origin"] = ( + mixin_spec["origin"] = ( "+." + vendor + "." @@ -208,7 +233,7 @@ def get_component(self, name, version=None, flags=None): mixin["vendor"] + "." + mixin["name"] ] ) - mixin_info["origin"] = ( + mixin_spec["origin"] = ( "+." + mixin["vendor"] + "." @@ -223,7 +248,7 @@ def get_component(self, name, version=None, flags=None): else: try: mixin_args = copy.deepcopy(self.data["mixins"][mixin["name"]]) - mixin_info["origin"] = self.data["mixins"]["@"].get( + mixin_spec["origin"] = self.data["mixins"]["@"].get( mixin["name"] ) except KeyError: @@ -235,33 +260,32 @@ def get_component(self, name, version=None, flags=None): ) ) - # the mixin config can be overwritten by the local config so we update the mixin arg and write it back - config["args"] = update_dict(mixin_args["args"], config["args"], copy=True) + if mixin["overrides"]: + # override config + config["args"] = update_dict(config["args"], mixin_args["args"]) + else: + # config overrides mixin + config["args"] = update_dict(mixin_args["args"], config["args"]) - # write information - config["args"]["_mixins_"][i] = mixin_info + mixin_specs.append(mixin_spec) - # parse updates + config["args"]["_mixins_"] = mixin_specs - # merge local update to global updates + # collect versions config["versions"] = [] versions = copy.deepcopy(self.version) if version is not None: + # merge local update to global updates versions.append({"arguments": {"components": copy.deepcopy(version)}}) + # parse updates version = {} - for updates in _collect_updates(versions): - update = updates.get("components", None) - if update is None: - continue - + for update in _collect_updates(versions): if not isinstance(update, tuple): update = (update,) - for k in update: if k is None: continue - # load arguments from machinable.yaml if isinstance(k, str): config["versions"].append(k) k = self._get_version(k, config["args"]) @@ -282,11 +306,9 @@ def get_component(self, name, version=None, flags=None): config["args"], version, preserve_schema=self.schema_validation ) - # remove unused versions + # remove versions config["args"] = { - k: v if not k.startswith("~") else ":" - for k, v in config["args"].items() - if not k.startswith("~") or k in config["versions"] + k: v for k, v in config["args"].items() if not k.startswith("~") } return config @@ -294,7 +316,7 @@ def get_component(self, name, version=None, flags=None): def get(self, component, components=None): component = ExperimentComponent.create(component) node_config = self.get_component( - component.name, component.version, component.flags + component.name, component.version, component.mixins, component.flags ) if components is None: @@ -304,7 +326,10 @@ def get(self, component, components=None): for c in components: subcomponent = ExperimentComponent.create(c) component_config = self.get_component( - subcomponent.name, subcomponent.version, subcomponent.flags + subcomponent.name, + subcomponent.version, + subcomponent.mixins, + subcomponent.flags, ) if component_config is not None: components_config.append(component_config) diff --git a/src/machinable/config/parser.py b/src/machinable/config/parser.py index 5a79e279..092aef59 100644 --- a/src/machinable/config/parser.py +++ b/src/machinable/config/parser.py @@ -26,6 +26,12 @@ def parse_mixins(config, valid_only=False): if "name" not in mixin: raise ValueError(f"Mixin definition '{mixin}' must specify a name") + if mixin["name"].startswith("!"): + mixin["overrides"] = True + mixin["name"] = mixin["name"][1:] + else: + mixin["overrides"] = False + if "attribute" not in mixin: mixin["attribute"] = ( "_" + mixin["name"].replace("+.", "").replace(".", "_") + "_" diff --git a/src/machinable/experiment/experiment.py b/src/machinable/experiment/experiment.py index 41cafdb5..424bd2f6 100644 --- a/src/machinable/experiment/experiment.py +++ b/src/machinable/experiment/experiment.py @@ -8,7 +8,7 @@ class ExperimentComponent(Jsonable): - def __init__(self, name, version=None, checkpoint=None, flags=None): + def __init__(self, name, version=None, mixins=None, checkpoint=None, flags=None): """Experiment components # Arguments @@ -25,6 +25,7 @@ def __init__(self, name, version=None, checkpoint=None, flags=None): """ self.name = name self.version = version + self.mixins = mixins self.checkpoint = checkpoint if flags is None: flags = {} @@ -57,7 +58,7 @@ def create(cls, args: Union[Type, str, Tuple, "ExperimentComponent"]): def unpack(self): if isinstance(self.version, list): return [ - __class__(self.name, v, self.checkpoint, self.flags) + __class__(self.name, v, self.mixins, self.checkpoint, self.flags) for v in self.version ] @@ -67,6 +68,7 @@ def serialize(self): return ( self.name, copy.deepcopy(self.version), + copy.deepcopy(self.mixins), self.checkpoint, copy.deepcopy(self.flags), ) @@ -84,7 +86,10 @@ def __repr__(self): if self.name is None: return "machinable.C(None)" - return f"machinable.C({self.name}, version={self.version}, checkpoint={self.checkpoint}, flags={self.flags})" + return ( + f"machinable.C({self.name}, version={self.version}, mixins={self.mixins}," + f" checkpoint={self.checkpoint}, flags={self.flags})" + ) _latest = [None] @@ -249,13 +254,21 @@ def name(self, name: str): return self def component( - self, name, version=None, checkpoint=None, flags=None, resources=None + self, + name, + version=None, + mixins=None, + checkpoint=None, + flags=None, + resources=None, ): """Adds a component to the experiment # Arguments name: String, the name of the components as defined in the machinable.yaml version: dict|String, a configuration update to override its default config + mixins: List[String], a list of mixins to use. This will override any default mixins + `"^"` will be expanded as the default mixins specified in the machinable.yaml. checkpoint: String, optional URL to a checkpoint file from which the components will be restored flags: dict, optional flags to be passed to the component resources: dict, specifies the resources that are available to the component. @@ -279,7 +292,7 @@ def component( components - List of sub-component specifications """ return self.components( - node=ExperimentComponent(name, version, checkpoint, flags), + node=ExperimentComponent(name, version, mixins, checkpoint, flags), resources=resources, ) diff --git a/tests/config/config_interface_test.py b/tests/config/config_interface_test.py index 8e174fd6..f70a61a1 100644 --- a/tests/config/config_interface_test.py +++ b/tests/config/config_interface_test.py @@ -1,6 +1,6 @@ import pytest -from machinable import Experiment +from machinable import Experiment, C from machinable.config.interface import ConfigInterface from machinable.experiment.parser import parse_experiment from machinable.project import Project @@ -42,6 +42,42 @@ def to_config(project, schedule): return node_config["args"], components_config["args"] +def test_config_mixins(): + test_project = Project("./test_project") + + t = Experiment().component("mixexp", "~test", mixins=["version_mixin"]) + e, m = to_config(test_project, t) + assert "elephant" not in e + assert e["mixed_in"] is True + assert e["foo"] == 1 + + t = Experiment().component("mixexp", "~test", mixins=["^", "version_mixin"]) + e, m = to_config(test_project, t) + assert "elephant" in e + assert e["mixed_in"] is True + assert e["foo"] == 1 + + t = Experiment().component("thenode", "~test", mixins=["version_mixin"]) + e, m = to_config(test_project, t) + assert e["mixed_in"] is True + assert e["foo"] == 1 + + t = Experiment().component("thenode", "~test:nest", mixins=["^", "version_mixin"]) + e, m = to_config(test_project, t) + assert e["mixed_in"] is None + assert e["foo"] == 1 + assert e["ba"] == 2 + + t = Experiment().component( + "thenode", "~three:nested", mixins=["version_mixin", "^"] + ) + e, m = to_config(test_project, t) + assert e["alpha"] == 4 + assert e["added"] == "blocker" + assert e["unaffected"] == "value" + assert e["beta"] == "override" + + def test_config_versioning(): test_project = Project("./test_project") diff --git a/tests/test_project/machinable.yaml b/tests/test_project/machinable.yaml index f7204a49..3d5a5acc 100644 --- a/tests/test_project/machinable.yaml +++ b/tests/test_project/machinable.yaml @@ -19,6 +19,19 @@ mixins: - +.fooba.test=test: - +.fooba.mixins.nested=nested: - +.fooba.experiments.base=experiment_mixin: + - version_mixin: + mixed_in: true + ba: 1 + ~test: + foo: 1 + ~nest: + mixed_in: + ba: 2 + ~three: + alpha: 5 + ~nested: + added: blocker + beta: override components: - thenode: timeout: 0