gh-7563: add config wrapper #7730

Open · wants to merge 13 commits into base: `dev`
57 changes: 57 additions & 0 deletions docs/source/config_syntax.md
@@ -175,6 +175,63 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).

### Wrapping config components

> **EXPERIMENTAL FEATURE** This feature is experimental and may be subject to change or removal in future releases.

Sometimes it can be necessary to wrap (i.e. decorate) a component in the config without
shifting the configuration tree one level down.
Take the following configuration as an example:

```json
{
    "model": {
        "_target_": "monai.networks.nets.BasicUNet",
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 2,
        "features": [16, 16, 32, 32, 64, 64]
    }
}
```
If we wanted to use `torch.compile` to speed up the model, we would have to write a configuration like this:

```json
{
    "model": {
        "_target_": "torch.compile",
        "model": {
            "_target_": "monai.networks.nets.BasicUNet",
            "spatial_dims": 3,
            "in_channels": 1,
            "out_channels": 2,
            "features": [16, 16, 32, 32, 64, 64]
        }
    }
}
```
This means we would now need to adjust all references to parameters such as `model#spatial_dims` to `model#model#spatial_dims`
throughout our code and configuration.
To avoid this, we can use the `_wrapper_` key to wrap the model in the configuration:

```json
{
    "model": {
        "_target_": "monai.networks.nets.BasicUNet",
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 2,
        "features": [16, 16, 32, 32, 64, 64],
        "_wrapper_": {
            "_target_": "torch.compile",
            "_mode_": "callable"
        }
    }
}
```

Note that `@model` in the configuration now resolves to the compiled model.

**Review discussion** (on the `"_mode_": "callable"` line above):

> **Contributor:** Looks like the wrapper can only be callable; do we need to add this mode here?
>
> **Author (johnzielke):** There might be, and probably are, use cases where you would instantiate a factory here (a class with a `__call__` method) or have a function that returns another function (in fact, `torch.compile` probably does this, as most decorators that support arguments do).

> **Contributor:** Hi @johnzielke, after discussing with @Nic-Ma offline, we found that such a wrapper can easily be achieved with code like this:
>
> ```python
> from monai.bundle import ConfigParser
>
> config = {
>     "model_base": {
>         "_target_": "monai.networks.nets.BasicUNet",
>         "spatial_dims": 3,
>         "in_channels": 1,
>         "out_channels": 2,
>         "features": [16, 16, 32, 32, 64, 64]
>     },
>     "model": "$torch.compile(@model_base)"
> }
> parser = ConfigParser(config=config)
> parser.parse()
> net = parser.get_parsed_content("model")
> print(net)
> ```
>
> Does this meet your needs? It is like how we use DDP here:
> https://github.com/Project-MONAI/model-zoo/blob/dev/models/spleen_ct_segmentation/configs/multi_gpu_train.json#L3
>
> cc @Nic-Ma @ericspod

> **Member:** We have used this pattern before, and for my use cases I feel it is fine. I'd like to hear if there are other cases where this PR makes more sense.

> **Author (johnzielke):** Yes, that pattern works in many cases, but it requires you to move definitions to other keys, which might be undesirable for several reasons:
>
> - A lot of references in code and in overriding configs need to be adjusted if this has not been planned from the beginning.
> - If the component you are wrapping is used as a parameter of a parent instantiation via `_target_`, you cannot simply add the `model_base` key at the same level if the parent class does not handle extra kwargs gracefully.
> - If your wrapping function has parameters (for example, the dynamic shapes in `torch.compile`), you need to specify them in a single string, which makes them harder to modify via configs. Of course, you could also introduce another config entry called `model_decorator` and then have `model` be `@model_decorator(@model_base)`, but that adds a lot of visual noise in my opinion.
>
> I'll try to compile some other use cases later, but one I could see would be wrapping existing `Dataset`s in `CacheDataset`s or similar. In that case you would often have a dictionary of train, val, and test datasets. The same pattern can be applied again, of course, but in my opinion it makes the configs harder to read.
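For illustration, the dataset-wrapping case mentioned in the discussion above can be sketched as follows; this is a minimal example mirroring this PR's test cases (the toy `data` list is illustrative, and the experimental feature flag must be enabled):

```json
{
    "dataset": {
        "_target_": "Dataset",
        "data": [1, 2],
        "_wrapper_": {
            "_target_": "CacheDataset",
            "_mode_": "callable"
        }
    }
}
```

Because the wrapper uses `"_mode_": "callable"`, it resolves to the `CacheDataset` class itself, which is then called with the instantiated `Dataset` as its first argument, so `@dataset` yields a `CacheDataset`.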

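The sketch below shows one way to enable the experimental flag and parse a wrapped config end to end; it assumes the feature-flag mechanism introduced by this PR and abbreviates the model arguments:

```python
from monai.bundle import ConfigParser
from monai.bundle.config_item import _wrapper_feature_flag

# enable the experimental wrapper support; alternatively, set the
# environment variable MONAI_FEATURE_ENABLED_CONFIG_WRAPPER=1
_wrapper_feature_flag.enable()

config = {
    "model": {
        "_target_": "monai.networks.nets.BasicUNet",
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 2,
        "_wrapper_": {"_target_": "torch.compile", "_mode_": "callable"},
    }
}
parser = ConfigParser(config=config)
parser.parse()
model = parser.get_parsed_content("model")  # resolves to the compiled model
```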
## The command line interface

In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.
91 changes: 79 additions & 12 deletions monai/bundle/config_item.py
@@ -19,13 +19,24 @@
from collections.abc import Mapping, Sequence
from importlib import import_module
from pprint import pformat
-from typing import Any
+from typing import Any, Callable

from monai.bundle.utils import EXPR_KEY
from monai.utils import CompInitMode, ensure_tuple, first, instantiate, optional_import, run_debug, run_eval

__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent", "Instantiable"]

from monai.utils.feature_flag import FeatureFlag

CONFIG_COMPONENT_KEY_MODE = "_mode_"
CONFIG_COMPONENT_KEY_DESC = "_desc_"
CONFIG_COMPONENT_KEY_REQUIRES = "_requires_"
CONFIG_COMPONENT_KEY_DISABLED = "_disabled_"
CONFIG_COMPONENT_KEY_TARGET = "_target_"
CONFIG_COMPONENT_KEY_WRAPPER = "_wrapper_"

_wrapper_feature_flag = FeatureFlag("CONFIG_WRAPPER", default=False)


class Instantiable(ABC):
"""
@@ -166,7 +177,7 @@ class ConfigComponent(ConfigItem, Instantiable):
Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to
represent a component of `class` or `function` and supports instantiation.

-Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
+Currently, four special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:

- class or function identifier of the python module, specified by ``"_target_"``,
indicating a monai built-in Python class or function such as ``"LoadImageDict"``,
@@ -183,6 +194,12 @@ class ConfigComponent(ConfigItem, Instantiable):
- ``"default"``: returns ``component(**kwargs)``
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
- ``"_wrapper_"`` (optional): a callable that wraps the instantiation of the component.
This feature is currently experimental and hidden behind a feature flag. To enable it, set the
environment variable ``MONAI_FEATURE_ENABLED_CONFIG_WRAPPER=1`` or
call monai.bundle.config_item._wrapper_feature_flag.enable().
The callable should take the instantiated component as input and return the wrapped component.
A use case of this can be torch.compile(). See the Config Guide for more details.

Other fields in the config content are input arguments to the python module.

@@ -210,7 +227,14 @@ class ConfigComponent(ConfigItem, Instantiable):

"""

non_arg_keys = {"_target_", "_disabled_", "_requires_", "_desc_", "_mode_"}
non_arg_keys = {
johnzielke marked this conversation as resolved.
Show resolved Hide resolved
CONFIG_COMPONENT_KEY_TARGET,
CONFIG_COMPONENT_KEY_DISABLED,
CONFIG_COMPONENT_KEY_REQUIRES,
CONFIG_COMPONENT_KEY_DESC,
CONFIG_COMPONENT_KEY_MODE,
CONFIG_COMPONENT_KEY_WRAPPER,
}

def __init__(
self,
@@ -231,7 +255,7 @@ def is_instantiable(config: Any) -> bool:
config: input config content to check.

"""
-return isinstance(config, Mapping) and "_target_" in config
+return isinstance(config, Mapping) and CONFIG_COMPONENT_KEY_TARGET in config

def resolve_module_name(self):
"""
@@ -240,7 +264,7 @@ def resolve_module_name(self):

"""
config = dict(self.get_config())
target = config.get("_target_")
target = config.get(CONFIG_COMPONENT_KEY_TARGET)
if not isinstance(target, str):
return target # for feature discussed in project-monai/monai#5852

@@ -262,34 +286,77 @@ def resolve_args(self):
Utility function used in `instantiate()` to resolve the arguments from current config content.

"""
-return {k: v for k, v in self.get_config().items() if k not in self.non_arg_keys}
+return {
+    k: v
+    for k, v in self.get_config().items()
+    if (k not in self.non_arg_keys) or (k == CONFIG_COMPONENT_KEY_WRAPPER and not _wrapper_feature_flag.enabled)
+}

def is_disabled(self) -> bool:
"""
Utility function used in `instantiate()` to check whether to skip the instantiation.

"""
_is_disabled = self.get_config().get("_disabled_", False)
_is_disabled = self.get_config().get(CONFIG_COMPONENT_KEY_DISABLED, False)
return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled)

def _get_wrapper(self) -> None | Callable[[object], object]:
    """
    Utility function used in `instantiate()` to check if a wrapper is specified in the config.

    """
    wrapper = self.get_config().get(CONFIG_COMPONENT_KEY_WRAPPER, None)
    if _wrapper_feature_flag.enabled:
        if wrapper is not None:
            if callable(wrapper):
                return wrapper  # type: ignore
            else:
                raise ValueError(
                    f"wrapper must be a callable, but got type {type(wrapper)}: {wrapper}. "
                    "Make sure all references are resolved before calling instantiate "
                    "and that the wrapper is a callable."
                )
    elif wrapper is not None:
        warnings.warn(
            f"ConfigComponent: {self.get_id()} has a key {CONFIG_COMPONENT_KEY_WRAPPER}. "
            "Since the feature flag CONFIG_WRAPPER is not enabled, the key will be treated as a normal config key. "
            "In future versions of MONAI, this key might be reserved for the wrapper functionality."
        )
    return None

**Review discussion** (on the warning message above):

> **Contributor:** It looks like in this case the wrapper will become `None` instead of a normal config key?
>
> **Author (johnzielke):** Sorry, I'm not quite following; can you explain a bit further?

def instantiate(self, **kwargs: Any) -> object:
    """
    Instantiate component based on ``self.config`` content.
    The target component must be a `class` or a `function`, otherwise, return `None`.

    Args:
-        kwargs: args to override / add the config args when instantiation.
+        kwargs: kwargs to override or add to the config-derived instantiate_kwargs at instantiation.

    """
    if not self.is_instantiable(self.get_config()) or self.is_disabled():
        # if not a class or function or marked as `disabled`, skip parsing and return `None`
        return None

    modname = self.resolve_module_name()
-    mode = self.get_config().get("_mode_", CompInitMode.DEFAULT)
-    args = self.resolve_args()
-    args.update(kwargs)
-    return instantiate(modname, mode, **args)
+    mode = self.get_config().get(CONFIG_COMPONENT_KEY_MODE, CompInitMode.DEFAULT)
+    instantiate_kwargs = self.resolve_args()
+    instantiate_kwargs.update(kwargs)
+    wrapper = self._get_wrapper()
+    if wrapper is not None:
+        return wrapper(instantiate(modname, mode, **instantiate_kwargs))
+    try:
+        return instantiate(modname, mode, **instantiate_kwargs)
+    except Exception as e:
+        if _wrapper_feature_flag.enabled and self.get_id().endswith(CONFIG_COMPONENT_KEY_WRAPPER):
+            raise RuntimeError(
+                f"Failed to instantiate {self}. Make sure you are returning a partial "
+                f"(you might need to add {CONFIG_COMPONENT_KEY_MODE}:callable, "
+                "especially when specifying a class)."
+            ) from e
+        else:
+            # re-raise the exception if not using the wrapper
+            raise

**Review discussion** (on the `return wrapper(instantiate(...))` line above):

> **Contributor:** Will the wrapper here accept kwargs, such as the `mode` that `torch.compile` can accept?
> https://github.com/pytorch/pytorch/blob/4333e122d4b74cdf84351ed2907045c6a767b4cd/torch/compiler/__init__.py#L17
>
> **Author (johnzielke), May 16, 2024:** Yes, the wrapper key is "instantiated" like any other key and will therefore accept kwargs. So by using `_mode_: callable` for `torch.compile`, you can bind any kwargs to it, e.g. this pseudo-code:
>
> ```yaml
> _target_: nets.Unet
> spatial_dim: 3
> _wrapper_:
>   _target_: torch.compile
>   _mode_: callable
>   dynamic: true
>   fullgraph: true
> ```
>
> So it doesn't accept them at this line, but they should already be bound by this point.


class ConfigExpression(ConfigItem):
61 changes: 61 additions & 0 deletions monai/utils/feature_flag.py
@@ -0,0 +1,61 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import os
from contextlib import contextmanager

FEATURE_FLAG_PREFIX = "MONAI_FEATURE_ENABLED_"


class FeatureFlag:
    def __init__(self, name: str, *, default: bool = False):
        self.name = name
        self._enabled: bool | None = None
        self.default = default

    def _get_from_env(self):
        return os.getenv(FEATURE_FLAG_PREFIX + self.name, None)

    @property
    def enabled(self) -> bool:
        if self._enabled is None:
            env = self._get_from_env()
            if env is None:
                self._enabled = self.default
            else:
                self._enabled = env.lower() in ["true", "1", "yes"]
        return self._enabled

    @enabled.setter
    def enabled(self, value: bool) -> None:
        self._enabled = value

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def __str__(self):
        return f"{self.name}: {self.enabled}, default: {self.default}"


@contextmanager
def with_feature_flag(feature_flag: FeatureFlag, enabled: bool):  # type: ignore
    original = feature_flag.enabled
    feature_flag.enabled = enabled
    try:
        yield
    finally:
        feature_flag.enabled = original
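A quick usage sketch of this helper, illustrative only: `DEMO` is a hypothetical flag name, resolved from the `MONAI_FEATURE_ENABLED_DEMO` environment variable.

```python
from monai.utils.feature_flag import FeatureFlag, with_feature_flag

# "DEMO" is a hypothetical flag; it is read lazily from the
# MONAI_FEATURE_ENABLED_DEMO environment variable, else the default applies
demo_flag = FeatureFlag("DEMO", default=False)
print(demo_flag.enabled)  # False, assuming the variable is unset

with with_feature_flag(demo_flag, True):
    assert demo_flag.enabled  # temporarily enabled inside the context

assert not demo_flag.enabled  # restored to its previous value afterwards
```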
9 changes: 9 additions & 0 deletions tests/test_config_item.py
@@ -20,9 +20,11 @@

import monai
from monai.bundle import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.config_item import _wrapper_feature_flag
from monai.data import DataLoader, Dataset
from monai.transforms import LoadImaged, RandTorchVisiond
from monai.utils import min_version, optional_import
from monai.utils.feature_flag import with_feature_flag

_, has_tv = optional_import("torchvision", "0.8.0", min_version)

@@ -133,6 +135,13 @@ def test_error_expr(self):
        with self.assertRaisesRegex(RuntimeError, r"1\+\[\]"):
            ConfigExpression(id="", config="$1+[]").evaluate()

    def test_wrapper(self):
        with with_feature_flag(_wrapper_feature_flag, True):
            config = {"_target_": "fractions.Fraction", "numerator": 5, "denominator": 10, "_wrapper_": float}
            configer = ConfigComponent(config=config, locator=None)
            ret = configer.instantiate()
            self.assertTrue(isinstance(ret, float))


if __name__ == "__main__":
    unittest.main()
43 changes: 41 additions & 2 deletions tests/test_config_parser.py
@@ -15,17 +15,19 @@
import tempfile
import unittest
import warnings
from collections import OrderedDict
from pathlib import Path
from unittest import mock, skipUnless

import numpy as np
from parameterized import parameterized

from monai.bundle import ConfigParser, ReferenceResolver
-from monai.bundle.config_item import ConfigItem
+from monai.bundle.config_item import CONFIG_COMPONENT_KEY_WRAPPER, ConfigItem, _wrapper_feature_flag
-from monai.data import DataLoader, Dataset
+from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, RandTorchVisiond
from monai.utils import min_version, optional_import
from monai.utils.feature_flag import with_feature_flag
from tests.utils import TimedCall

_, has_tv = optional_import("torchvision", "0.8.0", min_version)
@@ -124,6 +126,30 @@ def __call__(self, a, b):
    1,
    [0, 4],
]
TEST_CASE_WRAPPER_ENABLED = [
    {
        "dataset": {
            "_target_": "Dataset",
            "data": [1, 2],
            CONFIG_COMPONENT_KEY_WRAPPER: {"_target_": "CacheDataset", "_mode_": "callable"},
        }
    },
    ["dataset", f"dataset#{CONFIG_COMPONENT_KEY_WRAPPER}"],
    [CacheDataset, type(CacheDataset)],
    True,
]
TEST_CASE_WRAPPER_DISABLED = [
    {
        "dataset": {
            "_target_": "collections.OrderedDict",
            "data": [1, 2],
            CONFIG_COMPONENT_KEY_WRAPPER: {"_target_": "CacheDataset", "_mode_": "callable"},
        }
    },
    ["dataset", f"dataset#{CONFIG_COMPONENT_KEY_WRAPPER}"],
    [OrderedDict, type(CacheDataset)],
    False,
]


class TestConfigParser(unittest.TestCase):
@@ -357,6 +383,19 @@ def test_parse_json_warn(self, config_string, extension, expected_unique_val, ex
        self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val)
        self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals)

    @parameterized.expand([TEST_CASE_WRAPPER_ENABLED, TEST_CASE_WRAPPER_DISABLED])
    def test_parse_wrapper(self, config, expected_ids, output_types, enable_feature_flag):
        with with_feature_flag(_wrapper_feature_flag, enable_feature_flag):
            parser = ConfigParser(config=config, globals={"monai": "monai", "torch": "torch"})
            for id, cls in zip(expected_ids, output_types):
                self.assertTrue(isinstance(parser.get_parsed_content(id), cls))
            # test root content
            root = parser.get_parsed_content(id="")
            for v, cls in zip(root.values(), output_types):
                self.assertTrue(isinstance(v, cls))
            if not enable_feature_flag:
                self.assertIn(CONFIG_COMPONENT_KEY_WRAPPER, root["dataset"])


if __name__ == "__main__":
    unittest.main()