Skip to content

Commit

Permalink
[DLMED] update config item
Browse files Browse the repository at this point in the history
Signed-off-by: Nic Ma <[email protected]>
  • Loading branch information
Nic-Ma committed Mar 10, 2022
1 parent c89e3a2 commit ca16681
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 35 deletions.
43 changes: 22 additions & 21 deletions monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,22 +163,21 @@ class ConfigComponent(ConfigItem, Instantiable):
Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to
represent a component of `class` or `function` and supports instantiation.
Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals:
Currently, four special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
- class or function identifier of the python module, specified by one of the two keys.
- ``"<name>"``: indicates build-in python classes or functions such as "LoadImageDict".
- ``"<path>"``: full module name, such as "monai.transforms.LoadImageDict".
- ``"<args>"``: input arguments to the python module.
- ``"<disabled>"``: a flag to indicate whether to skip the instantiation.
- ``"_name_"``: indicates build-in python classes or functions such as "LoadImageDict".
- ``"_path_"``: full module name, such as "monai.transforms.LoadImageDict".
- ``"_disabled_"``: a flag to indicate whether to skip the instantiation.
Other fields in the config content are input arguments to the python module.
.. code-block:: python
locator = ComponentLocator(excludes=["modules_to_exclude"])
config = {
"<name>": "LoadImaged",
"<args>": {
"keys": ["image", "label"]
}
"_name_": "LoadImaged",
"keys": ["image", "label"]
}
configer = ConfigComponent(config, id="test", locator=locator)
Expand All @@ -195,6 +194,8 @@ class ConfigComponent(ConfigItem, Instantiable):
"""

not_arg_keys = ["_name_", "_path_", "_disabled_"]

def __init__(
self,
config: Any,
Expand All @@ -214,35 +215,35 @@ def is_instantiable(config: Any) -> bool:
config: input config content to check.
"""
return isinstance(config, Mapping) and ("<path>" in config or "<name>" in config)
return isinstance(config, Mapping) and ("_path_" in config or "_name_" in config)

def resolve_module_name(self):
"""
Resolve the target module name from current config content.
The config content must have ``"<path>"`` or ``"<name>"``.
When both are specified, ``"<path>"`` will be used.
The config content must have ``"_path_"`` or ``"_name_"`` key.
When both are specified, ``"_path_"`` will be used.
"""
config = dict(self.get_config())
path = config.get("<path>")
path = config.get("_path_")
if path is not None:
if not isinstance(path, str):
raise ValueError(f"'<path>' must be a string, but got: {path}.")
if "<name>" in config:
warnings.warn(f"both '<path>' and '<name>', default to use '<path>': {path}.")
raise ValueError(f"'_path_' must be a string, but got: {path}.")
if "_name_" in config:
warnings.warn(f"both '_path_' and '_name_', default to use '_path_': {path}.")
return path

name = config.get("<name>")
name = config.get("_name_")
if not isinstance(name, str):
raise ValueError("must provide a string for `<path>` or `<name>` of target component to instantiate.")
raise ValueError("must provide a string for `_path_` or `_name_` of target component to instantiate.")

module = self.locator.get_component_module_name(name)
if module is None:
raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.")
if isinstance(module, list):
warnings.warn(
f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}."
f" if want to use others, please set its module path in `<path>` directly."
f" if want to use others, please set its module path in `_path_` directly."
)
module = module[0]
return f"{module}.{name}"
Expand All @@ -252,14 +253,14 @@ def resolve_args(self):
Utility function used in `instantiate()` to resolve the arguments from current config content.
"""
return self.get_config().get("<args>", {})
return {k: v for k, v in self.get_config().items() if k not in self.not_arg_keys}

def is_disabled(self) -> bool: # type: ignore
"""
Utility function used in `instantiate()` to check whether to skip the instantiation.
"""
_is_disabled = self.get_config().get("<disabled>", False)
_is_disabled = self.get_config().get("_disabled_", False)
return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled)

def instantiate(self, **kwargs) -> object: # type: ignore
Expand Down
28 changes: 14 additions & 14 deletions tests/test_config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@

TEST_CASE_1 = [{"lr": 0.001}, 0.0001]

TEST_CASE_2 = [{"<name>": "LoadImaged", "<args>": {"keys": ["image"]}}, LoadImaged]
# test python `<path>`
TEST_CASE_3 = [{"<path>": "monai.transforms.LoadImaged", "<args>": {"keys": ["image"]}}, LoadImaged]
# test `<disabled>`
TEST_CASE_4 = [{"<name>": "LoadImaged", "<disabled>": True, "<args>": {"keys": ["image"]}}, dict]
# test `<disabled>`
TEST_CASE_5 = [{"<name>": "LoadImaged", "<disabled>": "true", "<args>": {"keys": ["image"]}}, dict]
TEST_CASE_2 = [{"_name_": "LoadImaged", "keys": ["image"]}, LoadImaged]
# test python `_path_`
TEST_CASE_3 = [{"_path_": "monai.transforms.LoadImaged", "keys": ["image"]}, LoadImaged]
# test `_disabled_`
TEST_CASE_4 = [{"_name_": "LoadImaged", "_disabled_": True, "keys": ["image"]}, dict]
# test `_disabled_` with string
TEST_CASE_5 = [{"_name_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict]
# test non-monai modules and excludes
TEST_CASE_6 = [
{"<path>": "torch.optim.Adam", "<args>": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}},
{"_path_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4},
torch.optim.Adam,
]
TEST_CASE_7 = [{"<name>": "decollate_batch", "<args>": {"detach": True, "pad": True}}, partial]
TEST_CASE_7 = [{"_name_": "decollate_batch", "detach": True, "pad": True}, partial]
# test args contains "name" field
TEST_CASE_8 = [
{"<name>": "RandTorchVisiond", "<args>": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}},
{"_name_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
RandTorchVisiond,
]
# test execute some function in args, test pre-imported global packages `monai`
Expand All @@ -67,8 +67,8 @@ def test_component(self, test_input, output_type):
locator = ComponentLocator(excludes=["metrics"])
configer = ConfigComponent(id="test", config=test_input, locator=locator)
ret = configer.instantiate()
if test_input.get("<disabled>", False):
# test `<disabled>` works fine
if test_input.get("_disabled_", False):
# test `_disabled_` works fine
self.assertEqual(ret, None)
return
self.assertTrue(isinstance(ret, output_type))
Expand All @@ -83,11 +83,11 @@ def test_expression(self, id, test_input):
self.assertTrue(isinstance(ret, Callable))

def test_lazy_instantiation(self):
config = {"<name>": "DataLoader", "<args>": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}
config = {"_name_": "DataLoader", "dataset": Dataset(data=[1, 2]), "batch_size": 2}
configer = ConfigComponent(config=config, locator=None)
init_config = configer.get_config()
# modify config content at runtime
init_config["<args>"]["batch_size"] = 4
init_config["batch_size"] = 4
configer.update_config(config=init_config)

ret = configer.instantiate()
Expand Down

0 comments on commit ca16681

Please sign in to comment.