Skip to content

Commit

Permalink
Add mixins=[] option to extend and overwrite mixins in the experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed Nov 24, 2020
1 parent ec63acb commit 325657b
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 69 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

<!-- Please add changes under the Unreleased section that reads 'No current changes' otherwise -->

# v2.9.2

- Mixins can be specified, extended and overwritten via `Experiment().component(mixins=[])`

# v2.9.1

- New nested version syntax `~nested:version`
Expand Down
151 changes: 88 additions & 63 deletions src/machinable/config/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,23 @@ def _collect_updates(version):
if not isinstance(arguments, tuple):
arguments = (arguments,)

collection.extend(arguments)
collection.extend([a["components"] for a in arguments if "components" in a])

return collection


def _mark_mixin_as_override(mixin):
if isinstance(mixin, str):
mixin = ("!" + mixin) if not mixin.startswith("!") else mixin
if isinstance(mixin, dict):
mixin["name"] = (
("!" + mixin["name"])
if not mixin["name"].startswith("!")
else mixin["name"]
)
return mixin


def mapped_config(config):
if isinstance(config, list):
return [mapped_config(c) for c in config]
Expand All @@ -42,6 +54,38 @@ def __init__(self, parsed_config, version=None, default_class=None):
self.default_class = default_class
self.schema_validation = get_settings()["schema_validation"]

@staticmethod
def call_with_context(function, component, components=None, resources=None):
signature = inspect.signature(function)
payload = OrderedDict()

for index, (key, parameter) in enumerate(signature.parameters.items()):
if parameter.kind is not parameter.POSITIONAL_OR_KEYWORD:
# disallow *args and **kwargs
raise TypeError(
f"Method only allows simple positional or keyword arguments,"
f"for example lambda(node, components, resources)"
)

if key == "component":
payload["component"] = mapped_config(component)
elif key == "config":
payload["config"] = config_map(component["args"])
elif key == "flags":
payload["flags"] = config_map(component["flags"])
elif components is not None and key == "components":
payload["components"] = mapped_config(components)
elif resources is not None and key == "resources":
payload["resources"] = resources
else:
raise ValueError(
f"Unrecognized argument: '{key}'. "
f"Experiment directory takes the following arguments: "
f"'node', 'components' and 'resources'"
)

return function(**payload)

def _get_version(self, name, config):
# from yaml file
if name.endswith(".yaml") or name.endswith(".json"):
Expand All @@ -58,10 +102,6 @@ def _get_version(self, name, config):
f"Did you register it under 'mixins'?\n"
)

# from yaml string
if not name.startswith("~"):
return yaml.load(name, Loader=yaml.FullLoader)

# from local version
version = {}
path = name[1:].split(":")
Expand All @@ -85,39 +125,7 @@ def _get_version(self, name, config):
f"Available versions: {[f for f in config.keys() if f.startswith('~')]}\n"
)

@staticmethod
def call_with_context(function, component, components=None, resources=None):
signature = inspect.signature(function)
payload = OrderedDict()

for index, (key, parameter) in enumerate(signature.parameters.items()):
if parameter.kind is not parameter.POSITIONAL_OR_KEYWORD:
# disallow *args and **kwargs
raise TypeError(
f"Method only allows simple positional or keyword arguments,"
f"for example lambda(node, components, resources)"
)

if key == "component":
payload["component"] = mapped_config(component)
elif key == "config":
payload["config"] = config_map(component["args"])
elif key == "flags":
payload["flags"] = config_map(component["flags"])
elif components is not None and key == "components":
payload["components"] = mapped_config(components)
elif resources is not None and key == "resources":
payload["resources"] = resources
else:
raise ValueError(
f"Unrecognized argument: '{key}'. "
f"Experiment directory takes the following arguments: "
f"'node', 'components' and 'resources'"
)

return function(**payload)

def get_component(self, name, version=None, flags=None):
def get_component(self, name, version=None, mixins=None, flags=None):
if name is None:
return None

Expand Down Expand Up @@ -170,9 +178,26 @@ def get_component(self, name, version=None, flags=None):
# un-alias
origin = self.data["components"]["@"][name]

# mixins
for i, mixin in enumerate(parse_mixins(config["args"].get("_mixins_"))):
mixin_info = {"name": mixin["name"]}
# parse mixins
default_mixins = config["args"].pop("_mixins_", None)
if mixins is None:
mixins = default_mixins
else:
if default_mixins is None:
default_mixins = []
if not isinstance(default_mixins, (list, tuple)):
default_mixins = [default_mixins]
merged = []
for m in mixins:
if m == "^":
merged.extend(default_mixins)
else:
merged.append(_mark_mixin_as_override(m))
mixins = merged

mixin_specs = []
for mixin in parse_mixins(mixins):
mixin_spec = {"name": mixin["name"]}

if mixin["name"].startswith("+."):
raise AttributeError(
Expand All @@ -194,7 +219,7 @@ def get_component(self, name, version=None, flags=None):
)

# un-alias
mixin_info["origin"] = (
mixin_spec["origin"] = (
"+."
+ vendor
+ "."
Expand All @@ -208,7 +233,7 @@ def get_component(self, name, version=None, flags=None):
mixin["vendor"] + "." + mixin["name"]
]
)
mixin_info["origin"] = (
mixin_spec["origin"] = (
"+."
+ mixin["vendor"]
+ "."
Expand All @@ -223,7 +248,7 @@ def get_component(self, name, version=None, flags=None):
else:
try:
mixin_args = copy.deepcopy(self.data["mixins"][mixin["name"]])
mixin_info["origin"] = self.data["mixins"]["@"].get(
mixin_spec["origin"] = self.data["mixins"]["@"].get(
mixin["name"]
)
except KeyError:
Expand All @@ -235,33 +260,32 @@ def get_component(self, name, version=None, flags=None):
)
)

# the mixin config can be overwritten by the local config so we update the mixin arg and write it back
config["args"] = update_dict(mixin_args["args"], config["args"], copy=True)
if mixin["overrides"]:
# override config
config["args"] = update_dict(config["args"], mixin_args["args"])
else:
# config overrides mixin
config["args"] = update_dict(mixin_args["args"], config["args"])

# write information
config["args"]["_mixins_"][i] = mixin_info
mixin_specs.append(mixin_spec)

# parse updates
config["args"]["_mixins_"] = mixin_specs

# merge local update to global updates
# collect versions
config["versions"] = []
versions = copy.deepcopy(self.version)
if version is not None:
# merge local update to global updates
versions.append({"arguments": {"components": copy.deepcopy(version)}})

# parse updates
version = {}
for updates in _collect_updates(versions):
update = updates.get("components", None)
if update is None:
continue

for update in _collect_updates(versions):
if not isinstance(update, tuple):
update = (update,)

for k in update:
if k is None:
continue
# load arguments from machinable.yaml
if isinstance(k, str):
config["versions"].append(k)
k = self._get_version(k, config["args"])
Expand All @@ -282,19 +306,17 @@ def get_component(self, name, version=None, flags=None):
config["args"], version, preserve_schema=self.schema_validation
)

# remove unused versions
# remove versions
config["args"] = {
k: v if not k.startswith("~") else ":"
for k, v in config["args"].items()
if not k.startswith("~") or k in config["versions"]
k: v for k, v in config["args"].items() if not k.startswith("~")
}

return config

def get(self, component, components=None):
component = ExperimentComponent.create(component)
node_config = self.get_component(
component.name, component.version, component.flags
component.name, component.version, component.mixins, component.flags
)

if components is None:
Expand All @@ -304,7 +326,10 @@ def get(self, component, components=None):
for c in components:
subcomponent = ExperimentComponent.create(c)
component_config = self.get_component(
subcomponent.name, subcomponent.version, subcomponent.flags
subcomponent.name,
subcomponent.version,
subcomponent.mixins,
subcomponent.flags,
)
if component_config is not None:
components_config.append(component_config)
Expand Down
6 changes: 6 additions & 0 deletions src/machinable/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def parse_mixins(config, valid_only=False):
if "name" not in mixin:
raise ValueError(f"Mixin definition '{mixin}' must specify a name")

if mixin["name"].startswith("!"):
mixin["overrides"] = True
mixin["name"] = mixin["name"][1:]
else:
mixin["overrides"] = False

if "attribute" not in mixin:
mixin["attribute"] = (
"_" + mixin["name"].replace("+.", "").replace(".", "_") + "_"
Expand Down
23 changes: 18 additions & 5 deletions src/machinable/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class ExperimentComponent(Jsonable):
def __init__(self, name, version=None, checkpoint=None, flags=None):
def __init__(self, name, version=None, mixins=None, checkpoint=None, flags=None):
"""Experiment components
# Arguments
Expand All @@ -25,6 +25,7 @@ def __init__(self, name, version=None, checkpoint=None, flags=None):
"""
self.name = name
self.version = version
self.mixins = mixins
self.checkpoint = checkpoint
if flags is None:
flags = {}
Expand Down Expand Up @@ -57,7 +58,7 @@ def create(cls, args: Union[Type, str, Tuple, "ExperimentComponent"]):
def unpack(self):
if isinstance(self.version, list):
return [
__class__(self.name, v, self.checkpoint, self.flags)
__class__(self.name, v, self.mixins, self.checkpoint, self.flags)
for v in self.version
]

Expand All @@ -67,6 +68,7 @@ def serialize(self):
return (
self.name,
copy.deepcopy(self.version),
copy.deepcopy(self.mixins),
self.checkpoint,
copy.deepcopy(self.flags),
)
Expand All @@ -84,7 +86,10 @@ def __repr__(self):
if self.name is None:
return "machinable.C(None)"

return f"machinable.C({self.name}, version={self.version}, checkpoint={self.checkpoint}, flags={self.flags})"
return (
f"machinable.C({self.name}, version={self.version}, mixins={self.mixins},"
f" checkpoint={self.checkpoint}, flags={self.flags})"
)


_latest = [None]
Expand Down Expand Up @@ -249,13 +254,21 @@ def name(self, name: str):
return self

def component(
self, name, version=None, checkpoint=None, flags=None, resources=None
self,
name,
version=None,
mixins=None,
checkpoint=None,
flags=None,
resources=None,
):
"""Adds a component to the experiment
# Arguments
name: String, the name of the components as defined in the machinable.yaml
version: dict|String, a configuration update to override its default config
mixins: List[String], a list of mixins to use. This will override any default mixins
`"^"` will be expanded as the default mixins specified in the machinable.yaml.
checkpoint: String, optional URL to a checkpoint file from which the components will be restored
flags: dict, optional flags to be passed to the component
resources: dict, specifies the resources that are available to the component.
Expand All @@ -279,7 +292,7 @@ def component(
components - List of sub-component specifications
"""
return self.components(
node=ExperimentComponent(name, version, checkpoint, flags),
node=ExperimentComponent(name, version, mixins, checkpoint, flags),
resources=resources,
)

Expand Down
Loading

0 comments on commit 325657b

Please sign in to comment.