diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index b6334ee9d5..9392f294f5 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -163,22 +163,21 @@ class ConfigComponent(ConfigItem, Instantiable): Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to represent a component of `class` or `function` and supports instantiation. - Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals: + Currently, four special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals: - class or function identifier of the python module, specified by one of the two keys. - - ``""``: indicates build-in python classes or functions such as "LoadImageDict". - - ``""``: full module name, such as "monai.transforms.LoadImageDict". - - ``""``: input arguments to the python module. - - ``""``: a flag to indicate whether to skip the instantiation. + - ``"_name_"``: indicates build-in python classes or functions such as "LoadImageDict". + - ``"_path_"``: full module name, such as "monai.transforms.LoadImageDict". + - ``"_disabled_"``: a flag to indicate whether to skip the instantiation. + + Other fields in the config content are input arguments to the python module. .. code-block:: python locator = ComponentLocator(excludes=["modules_to_exclude"]) config = { - "": "LoadImaged", - "": { - "keys": ["image", "label"] - } + "_name_": "LoadImaged", + "keys": ["image", "label"] } configer = ConfigComponent(config, id="test", locator=locator) @@ -195,6 +194,8 @@ class ConfigComponent(ConfigItem, Instantiable): """ + not_arg_keys = ["_name_", "_path_", "_disabled_"] + def __init__( self, config: Any, @@ -214,27 +215,27 @@ def is_instantiable(config: Any) -> bool: config: input config content to check. """ - return isinstance(config, Mapping) and ("" in config or "" in config) + return isinstance(config, Mapping) and ("_path_" in config or "_name_" in config) def resolve_module_name(self): """ Resolve the target module name from current config content. - The config content must have ``""`` or ``""``. - When both are specified, ``""`` will be used. + The config content must have ``"_path_"`` or ``"_name_"`` key. + When both are specified, ``"_path_"`` will be used. """ config = dict(self.get_config()) - path = config.get("") + path = config.get("_path_") if path is not None: if not isinstance(path, str): - raise ValueError(f"'' must be a string, but got: {path}.") - if "" in config: - warnings.warn(f"both '' and '', default to use '': {path}.") + raise ValueError(f"'_path_' must be a string, but got: {path}.") + if "_name_" in config: + warnings.warn(f"both '_path_' and '_name_', default to use '_path_': {path}.") return path - name = config.get("") + name = config.get("_name_") if not isinstance(name, str): - raise ValueError("must provide a string for `` or `` of target component to instantiate.") + raise ValueError("must provide a string for `_path_` or `_name_` of target component to instantiate.") module = self.locator.get_component_module_name(name) if module is None: @@ -242,7 +243,7 @@ def resolve_module_name(self): if isinstance(module, list): warnings.warn( f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}." - f" if want to use others, please set its module path in `` directly." + f" if want to use others, please set its module path in `_path_` directly." ) module = module[0] return f"{module}.{name}" @@ -252,14 +253,14 @@ def resolve_args(self): Utility function used in `instantiate()` to resolve the arguments from current config content. """ - return self.get_config().get("", {}) + return {k: v for k, v in self.get_config().items() if k not in self.not_arg_keys} def is_disabled(self) -> bool: # type: ignore """ Utility function used in `instantiate()` to check whether to skip the instantiation. """ - _is_disabled = self.get_config().get("", False) + _is_disabled = self.get_config().get("_disabled_", False) return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) def instantiate(self, **kwargs) -> object: # type: ignore diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 1284efab56..9ce08561f3 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -26,22 +26,22 @@ TEST_CASE_1 = [{"lr": 0.001}, 0.0001] -TEST_CASE_2 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] -# test python `` -TEST_CASE_3 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] -# test `` -TEST_CASE_4 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] -# test `` -TEST_CASE_5 = [{"": "LoadImaged", "": "true", "": {"keys": ["image"]}}, dict] +TEST_CASE_2 = [{"_name_": "LoadImaged", "keys": ["image"]}, LoadImaged] +# test python `_path_` +TEST_CASE_3 = [{"_path_": "monai.transforms.LoadImaged", "keys": ["image"]}, LoadImaged] +# test `_disabled_` +TEST_CASE_4 = [{"_name_": "LoadImaged", "_disabled_": True, "keys": ["image"]}, dict] +# test `_disabled_` with string +TEST_CASE_5 = [{"_name_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict] # test non-monai modules and excludes TEST_CASE_6 = [ - {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + {"_path_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam, ] -TEST_CASE_7 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] +TEST_CASE_7 = [{"_name_": "decollate_batch", "detach": True, "pad": True}, partial] # test args contains "name" field TEST_CASE_8 = [ - {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, + {"_name_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, RandTorchVisiond, ] # test execute some function in args, test pre-imported global packages `monai` @@ -67,8 +67,8 @@ def test_component(self, test_input, output_type): locator = ComponentLocator(excludes=["metrics"]) configer = ConfigComponent(id="test", config=test_input, locator=locator) ret = configer.instantiate() - if test_input.get("", False): - # test `` works fine + if test_input.get("_disabled_", False): + # test `_disabled_` works fine self.assertEqual(ret, None) return self.assertTrue(isinstance(ret, output_type)) @@ -83,11 +83,11 @@ def test_expression(self, id, test_input): self.assertTrue(isinstance(ret, Callable)) def test_lazy_instantiation(self): - config = {"": "DataLoader", "": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}} + config = {"_name_": "DataLoader", "dataset": Dataset(data=[1, 2]), "batch_size": 2} configer = ConfigComponent(config=config, locator=None) init_config = configer.get_config() # modify config content at runtime - init_config[""]["batch_size"] = 4 + init_config["batch_size"] = 4 configer.update_config(config=init_config) ret = configer.instantiate()