-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
gh-7563: add config wrapper #7730
base: dev
Are you sure you want to change the base?
Changes from 8 commits
5de6a84
22518de
734c115
a9a30e7
4a692da
2e46036
8785691
b187218
cbb4b41
6a4564f
36a6fea
c807545
7752cba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,13 +19,24 @@ | |
from collections.abc import Mapping, Sequence | ||
from importlib import import_module | ||
from pprint import pformat | ||
from typing import Any | ||
from typing import Any, Callable | ||
|
||
from monai.bundle.utils import EXPR_KEY | ||
from monai.utils import CompInitMode, ensure_tuple, first, instantiate, optional_import, run_debug, run_eval | ||
|
||
__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent", "Instantiable"] | ||
|
||
from monai.utils.feature_flag import FeatureFlag | ||
|
||
CONFIG_COMPONENT_KEY_MODE = "_mode_" | ||
johnzielke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
CONFIG_COMPONENT_KEY_DESC = "_desc_" | ||
CONFIG_COMPONENT_KEY_REQUIRES = "_requires_" | ||
CONFIG_COMPONENT_KEY_DISABLED = "_disabled_" | ||
CONFIG_COMPONENT_KEY_TARGET = "_target_" | ||
CONFIG_COMPONENT_KEY_WRAPPER = "_wrapper_" | ||
|
||
_wrapper_feature_flag = FeatureFlag("CONFIG_WRAPPER", default=False) | ||
|
||
|
||
class Instantiable(ABC): | ||
""" | ||
|
@@ -166,7 +177,7 @@ class ConfigComponent(ConfigItem, Instantiable): | |
Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to | ||
represent a component of `class` or `function` and supports instantiation. | ||
|
||
Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals: | ||
Currently, four special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals: | ||
|
||
- class or function identifier of the python module, specified by ``"_target_"``, | ||
indicating a monai built-in Python class or function such as ``"LoadImageDict"``, | ||
|
@@ -183,6 +194,12 @@ class ConfigComponent(ConfigItem, Instantiable): | |
- ``"default"``: returns ``component(**kwargs)`` | ||
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)`` | ||
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)`` | ||
- ``"_wrapper_"`` (optional): a callable that wraps the instantiation of the component. | ||
This feature is currently experimental and hidden behind a feature flag. To enable it, set the | ||
environment variable ``MONAI_FEATURE_ENABLED_CONFIG_WRAPPER=1`` or | ||
call monai.bundle.config_item._wrapper_feature_flag.enable(). | ||
The callable should take the instantiated component as input and return the wrapped component. | ||
A use case of this can be torch.compile(). See the Config Guide for more details. | ||
|
||
Other fields in the config content are input arguments to the python module. | ||
|
||
|
@@ -210,7 +227,14 @@ class ConfigComponent(ConfigItem, Instantiable): | |
|
||
""" | ||
|
||
non_arg_keys = {"_target_", "_disabled_", "_requires_", "_desc_", "_mode_"} | ||
non_arg_keys = { | ||
johnzielke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
CONFIG_COMPONENT_KEY_TARGET, | ||
CONFIG_COMPONENT_KEY_DISABLED, | ||
CONFIG_COMPONENT_KEY_REQUIRES, | ||
CONFIG_COMPONENT_KEY_DESC, | ||
CONFIG_COMPONENT_KEY_MODE, | ||
CONFIG_COMPONENT_KEY_WRAPPER, | ||
} | ||
|
||
def __init__( | ||
self, | ||
|
@@ -231,7 +255,7 @@ def is_instantiable(config: Any) -> bool: | |
config: input config content to check. | ||
|
||
""" | ||
return isinstance(config, Mapping) and "_target_" in config | ||
return isinstance(config, Mapping) and CONFIG_COMPONENT_KEY_TARGET in config | ||
|
||
def resolve_module_name(self): | ||
""" | ||
|
@@ -240,7 +264,7 @@ def resolve_module_name(self): | |
|
||
""" | ||
config = dict(self.get_config()) | ||
target = config.get("_target_") | ||
target = config.get(CONFIG_COMPONENT_KEY_TARGET) | ||
if not isinstance(target, str): | ||
return target # for feature discussed in project-monai/monai#5852 | ||
|
||
|
@@ -262,34 +286,77 @@ def resolve_args(self): | |
Utility function used in `instantiate()` to resolve the arguments from current config content. | ||
|
||
""" | ||
return {k: v for k, v in self.get_config().items() if k not in self.non_arg_keys} | ||
return { | ||
k: v | ||
for k, v in self.get_config().items() | ||
if (k not in self.non_arg_keys) or (k == CONFIG_COMPONENT_KEY_WRAPPER and not _wrapper_feature_flag.enabled) | ||
} | ||
|
||
def is_disabled(self) -> bool: | ||
""" | ||
Utility function used in `instantiate()` to check whether to skip the instantiation. | ||
|
||
""" | ||
_is_disabled = self.get_config().get("_disabled_", False) | ||
_is_disabled = self.get_config().get(CONFIG_COMPONENT_KEY_DISABLED, False) | ||
return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) | ||
|
||
def _get_wrapper(self) -> None | Callable[[object], object]: | ||
""" | ||
Utility function used in `instantiate()` to check if a wrapper is specified in the config. | ||
|
||
""" | ||
wrapper = self.get_config().get(CONFIG_COMPONENT_KEY_WRAPPER, None) | ||
if _wrapper_feature_flag.enabled: | ||
if wrapper is not None: | ||
if callable(wrapper): | ||
return wrapper # type: ignore | ||
else: | ||
raise ValueError( | ||
f"wrapper must be a callable, but got type {type(wrapper)}: {wrapper}." | ||
"make sure all references are resolved before calling instantiate " | ||
"and the wrapper is a callable." | ||
) | ||
elif wrapper is not None: | ||
warnings.warn( | ||
f"ConfigComponent: {self.get_id()} has a key {CONFIG_COMPONENT_KEY_WRAPPER}. " | ||
"Since the feature flag CONFIG_WRAPPER is not enabled, the key will be treated as a normal config key. " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like in this case the wrapper will become None instead of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I'm not quite following, can you explain a bit further? |
||
"In future versions of MONAI, this key might be reserved for the wrapper functionality." | ||
) | ||
return None | ||
|
||
def instantiate(self, **kwargs: Any) -> object: | ||
""" | ||
Instantiate component based on ``self.config`` content. | ||
The target component must be a `class` or a `function`, otherwise, return `None`. | ||
|
||
Args: | ||
kwargs: args to override / add the config args when instantiation. | ||
kwargs: instantiate_kwargs to override / add the config instantiate_kwargs when instantiation. | ||
|
||
""" | ||
if not self.is_instantiable(self.get_config()) or self.is_disabled(): | ||
# if not a class or function or marked as `disabled`, skip parsing and return `None` | ||
return None | ||
|
||
modname = self.resolve_module_name() | ||
mode = self.get_config().get("_mode_", CompInitMode.DEFAULT) | ||
args = self.resolve_args() | ||
args.update(kwargs) | ||
return instantiate(modname, mode, **args) | ||
mode = self.get_config().get(CONFIG_COMPONENT_KEY_MODE, CompInitMode.DEFAULT) | ||
instantiate_kwargs = self.resolve_args() | ||
instantiate_kwargs.update(kwargs) | ||
wrapper = self._get_wrapper() | ||
if wrapper is not None: | ||
return wrapper(instantiate(modname, mode, **instantiate_kwargs)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will the wrapper here accept There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the wrapper key should be "instantiated" like any other key, and therefore will accept kwargs. So by using mode: "callable" for torch.compile, you can bind any kwargs to it, e.g. pseudo-code:
So it doesn't accept them in this line, but should already have them bound by here. |
||
|
||
try: | ||
return instantiate(modname, mode, **instantiate_kwargs) | ||
except Exception as e: | ||
if _wrapper_feature_flag.enabled and self.get_id().endswith(CONFIG_COMPONENT_KEY_WRAPPER): | ||
raise RuntimeError( | ||
f"Failed to instantiate {self}. Make sure you are returning a partial " | ||
f"(you might need to add {CONFIG_COMPONENT_KEY_MODE}:callable, " | ||
"especially when using specifying a class)." | ||
) from e | ||
else: | ||
# re-raise the exception if not using the wrapper | ||
raise | ||
|
||
|
||
class ConfigExpression(ConfigItem): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from __future__ import annotations | ||
|
||
import os | ||
from contextlib import contextmanager | ||
|
||
FEATURE_FLAG_PREFIX = "MONAI_FEATURE_ENABLED_" | ||
|
||
|
||
class FeatureFlag: | ||
johnzielke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__(self, name: str, *, default: bool = False): | ||
self.name = name | ||
self._enabled: bool | None = None | ||
self.default = default | ||
|
||
def _get_from_env(self): | ||
return os.getenv(FEATURE_FLAG_PREFIX + self.name, None) | ||
|
||
@property | ||
def enabled(self) -> bool: | ||
if self._enabled is None: | ||
env = self._get_from_env() | ||
if env is None: | ||
self._enabled = self.default | ||
else: | ||
self._enabled = env.lower() in ["true", "1", "yes"] | ||
return self._enabled | ||
|
||
@enabled.setter | ||
def enabled(self, value: bool) -> None: | ||
self._enabled = value | ||
|
||
def enable(self): | ||
self.enabled = True | ||
|
||
def disable(self): | ||
self.enabled = False | ||
|
||
def __str__(self): | ||
return f"{self.name}: {self.enabled}, default: {self.default}" | ||
|
||
|
||
@contextmanager | ||
def with_feature_flag(feature_flag: FeatureFlag, enabled: bool): # type: ignore | ||
original = feature_flag.enabled | ||
feature_flag.enabled = enabled | ||
try: | ||
yield | ||
finally: | ||
feature_flag.enabled = original |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like the wrapper can only be callable, do we need add this mode here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There might/probably are use cases where you would instantiate a factory here (class with a call method) or have a function that returns another function (in fact torch.compile probably does this, as most decorators supporting arguments do)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @johnzielke, after discuss with @Nic-Ma offline, we find that such wrapper can be easily achieved by using code like this:
Does this meet your needs?
Like how we use DDP here:
https://github.com/Project-MONAI/model-zoo/blob/dev/models/spleen_ct_segmentation/configs/multi_gpu_train.json#L3
cc @Nic-Ma @ericspod
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have used this pattern before which I feel for my use cases is fine. I'd like to hear if there's other cases this PR makes more sense for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that patterns works in many cases, but it requires you to move definitions to other keys, which might be undesirable for multiple reasons:
@model_decorator(@model_base)
, but that adds a lot of visual noise in my opinionI'll try to compile some other use cases later, but one I could see would be to wrap existing Datasets in CacheDatasets or similar. In that case you would have often have a dictionary of train, val, test and the datasets. Of course the same pattern can be applied again, but it makes the configs harder to read in my opinion.