diff --git a/docs/requirements.txt b/docs/requirements.txt
index dcf5ef5b2a..c2d0f22dcb 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -25,3 +25,5 @@ mlflow
 tensorboardX
 imagecodecs; platform_system == "Linux"
 tifffile; platform_system == "Linux"
+pyyaml
+fire
diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst
index 03d4e07d17..22260d822f 100644
--- a/docs/source/bundle.rst
+++ b/docs/source/bundle.rst
@@ -32,3 +32,7 @@ Model Bundle
 ---------------
 .. autoclass:: ConfigParser
     :members:
+
+`Scripts`
+---------
+.. autofunction:: run
diff --git a/docs/source/installation.md b/docs/source/installation.md
index 15c372c385..29cf1eab66 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -190,10 +190,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
 - The options are
 ```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs]
+[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire]
 ```
 which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`,
-`tifffile`, `imagecodecs`, respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, respectively.
 - `pip install 'monai[all]'` installs all the optional dependencies.
diff --git a/environment-dev.yml b/environment-dev.yml
index ae41f21f1f..4491f87ceb 100644
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -42,6 +42,8 @@ dependencies:
   - transformers
   - mlflow
   - tensorboardX
+  - pyyaml
+  - fire
   - pip
   - pip:
     # pip for itk as conda-forge version only up to v5.1
diff --git a/monai/__init__.py b/monai/__init__.py
index a823a3e1e2..e56a2f3444 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -39,7 +39,7 @@

 # handlers_* have some external decorators the users may not have installed
 # *.so files and folder "_C" may not exist when the cpp extensions are not compiled
-excludes = "(^(monai.handlers))|((\\.so)$)|(^(monai._C))"
+excludes = "(^(monai.handlers))|(^(monai.bundle))|((\\.so)$)|(^(monai._C))"

 # load directory modules only, skip loading individual files
 load_submodules(sys.modules[__name__], False, exclude_pattern=excludes)
diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index 68e2d543bb..b411406e84 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -12,3 +12,5 @@
 from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
 from .config_parser import ConfigParser
 from .reference_resolver import ReferenceResolver
+from .scripts import run
+from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py
new file mode 100644
index 0000000000..7a87030bec
--- /dev/null
+++ b/monai/bundle/__main__.py
@@ -0,0 +1,19 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from monai.bundle.scripts import run
+
+if __name__ == "__main__":
+    from monai.utils import optional_import
+
+    fire, _ = optional_import("fire")
+    fire.Fire()
diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py
index 44cdd3c634..807b369f5d 100644
--- a/monai/bundle/config_item.py
+++ b/monai/bundle/config_item.py
@@ -17,6 +17,7 @@
 from importlib import import_module
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

+from monai.bundle.utils import EXPR_KEY
 from monai.utils import ensure_tuple, instantiate

 __all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"]
@@ -160,25 +161,24 @@ def __repr__(self) -> str:

 class ConfigComponent(ConfigItem, Instantiable):
     """
-    Subclass of :py:class:`monai.apps.ConfigItem`, this class uses a dictionary with string keys to
+    Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to
     represent a component of `class` or `function` and supports instantiation.

-    Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals:
+    Currently, two special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:

     - class or function identifier of the python module, specified by one of the two keys.
-        - ``"<name>"``: indicates build-in python classes or functions such as "LoadImageDict".
-        - ``"<path>"``: full module name, such as "monai.transforms.LoadImageDict".
-    - ``"<args>"``: input arguments to the python module.
-    - ``"<disabled>"``: a flag to indicate whether to skip the instantiation.
+        - ``"_target_"``: indicates built-in python classes or functions such as "LoadImageDict",
+          or a full module name, such as "monai.transforms.LoadImageDict".
+    - ``"_disabled_"``: a flag to indicate whether to skip the instantiation.
+
+    Other fields in the config content are input arguments to the python module.

     .. code-block:: python

         locator = ComponentLocator(excludes=["modules_to_exclude"])
         config = {
-            "<name>": "LoadImaged",
-            "<args>": {
-                "keys": ["image", "label"]
-            }
+            "_target_": "LoadImaged",
+            "keys": ["image", "label"]
         }

         configer = ConfigComponent(config, id="test", locator=locator)
@@ -195,6 +195,8 @@ class ConfigComponent(ConfigItem, Instantiable):

     """

+    non_arg_keys = {"_target_", "_disabled_"}
+
     def __init__(
         self,
         config: Any,
@@ -214,52 +216,45 @@ def is_instantiable(config: Any) -> bool:
         """
             config: input config content to check.
         """
-        return isinstance(config, Mapping) and ("<name>" in config or "<path>" in config)
+        return isinstance(config, Mapping) and "_target_" in config

     def resolve_module_name(self):
         """
         Resolve the target module name from current config content.
-        The config content must have ``"<path>"`` or ``"<name>"``.
-        When both are specified, ``"<path>"`` will be used.
+        The config content must have the ``"_target_"`` key.

         """
         config = dict(self.get_config())

-        path = config.get("<path>")
-        if path is not None:
-            if not isinstance(path, str):
-                raise ValueError(f"'<path>' must be a string, but got: {path}.")
-            if "<name>" in config:
-                warnings.warn(f"both '<path>' and '<name>', default to use '<path>': {path}.")
-            return path
-
-        name = config.get("<name>")
-        if not isinstance(name, str):
-            raise ValueError("must provide a string for `<path>` or `<name>` of target component to instantiate.")
+        target = config.get("_target_")
+        if not isinstance(target, str):
+            raise ValueError("must provide a string for the `_target_` of the component to instantiate.")

-        module = self.locator.get_component_module_name(name)
+        module = self.locator.get_component_module_name(target)
         if module is None:
-            raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.")
+            # target is the full module name, no need to parse
+            return target
+
         if isinstance(module, list):
             warnings.warn(
-                f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}."
-                f" if want to use others, please set its module path in `<path>` directly."
+                f"there is more than one component with name `{target}`: {module}, using the first one `{module[0]}`."
+                f" to use a different one, please set its full module path in `_target_` directly."
             )
             module = module[0]
-        return f"{module}.{name}"
+        return f"{module}.{target}"

     def resolve_args(self):
         """
         Utility function used in `instantiate()` to resolve the arguments from current config content.

         """
-        return self.get_config().get("<args>", {})
+        return {k: v for k, v in self.get_config().items() if k not in self.non_arg_keys}

     def is_disabled(self) -> bool:  # type: ignore
         """
         Utility function used in `instantiate()` to check whether to skip the instantiation.

         """
-        _is_disabled = self.get_config().get("<disabled>", False)
+        _is_disabled = self.get_config().get("_disabled_", False)
         return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled)
@@ -283,7 +278,7 @@ def instantiate(self, **kwargs) -> object:  # type: ignore

 class ConfigExpression(ConfigItem):
     """
-    Subclass of :py:class:`monai.apps.ConfigItem`, the `ConfigItem` represents an executable expression
+    Subclass of :py:class:`monai.bundle.ConfigItem`, the `ConfigItem` represents an executable expression
     (execute based on ``eval()``).

     See also:
@@ -308,7 +303,7 @@ class ConfigExpression(ConfigItem):

     """

-    prefix = "$"
+    prefix = EXPR_KEY
     run_eval = False if os.environ.get("MONAI_EVAL_EXPR", "1") == "0" else True

     def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None:
diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py
index 5ebcfd03b4..6fa7b3a2a2 100644
--- a/monai/bundle/config_parser.py
+++ b/monai/bundle/config_parser.py
@@ -10,11 +10,21 @@
 # limitations under the License.

 import importlib
+import json
+import re
 from copy import deepcopy
-from typing import Any, Dict, Optional, Sequence, Union
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Tuple, Union

 from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
 from monai.bundle.reference_resolver import ReferenceResolver
+from monai.bundle.utils import ID_SEP_KEY, MACRO_KEY
+from monai.config import PathLike
+from monai.utils import ensure_tuple, look_up_option, optional_import
+
+yaml, _ = optional_import("yaml")
+
+__all__ = ["ConfigParser"]


 class ConfigParser:
@@ -30,24 +40,22 @@ class ConfigParser:

     .. code-block:: python

-        from monai.apps import ConfigParser
+        from monai.bundle import ConfigParser

         config = {
             "my_dims": 2,
             "dims_1": "$@my_dims + 1",
-            "my_xform": {"<name>": "LoadImage"},
-            "my_net": {"<name>": "BasicUNet",
-                       "<args>": {"spatial_dims": "@dims_1", "in_channels": 1, "out_channels": 4}},
-            "trainer": {"<name>": "SupervisedTrainer",
-                        "<args>": {"network": "@my_net", "preprocessing": "@my_xform"}}
+            "my_xform": {"_target_": "LoadImage"},
+            "my_net": {"_target_": "BasicUNet", "spatial_dims": "@dims_1", "in_channels": 1, "out_channels": 4},
+            "trainer": {"_target_": "SupervisedTrainer", "network": "@my_net", "preprocessing": "@my_xform"}
         }
         # in the example $@my_dims + 1 is an expression, which adds 1 to the value of @my_dims
         parser = ConfigParser(config)

         # get/set configuration content, the set method should happen before calling parse()
-        print(parser["my_net"]["<args>"]["in_channels"])  # original input channels 1
-        parser["my_net"]["<args>"]["in_channels"] = 4  # change input channels to 4
-        print(parser["my_net"]["<args>"]["in_channels"])
+        print(parser["my_net"]["in_channels"])  # original input channels 1
+        parser["my_net"]["in_channels"] = 4  # change input channels to 4
+        print(parser["my_net"]["in_channels"])

         # instantiate the network component
         parser.parse(True)
@@ -70,13 +78,19 @@ class ConfigParser:

     See also:

-        - :py:class:`monai.apps.ConfigItem`
+        - :py:class:`monai.bundle.ConfigItem`
+        - :py:class:`monai.bundle.scripts.run`

     """

+    suffixes = ("json", "yaml", "yml")
+    suffix_match = rf".*\.({'|'.join(suffixes)})"
+    path_match = rf"({suffix_match}$)"
+    meta_key = "_meta_"  # field key to save metadata
+
     def __init__(
         self,
-        config: Any,
+        config: Any = None,
         excludes: Optional[Union[Sequence[str], str]] = None,
         globals: Optional[Dict[str, Any]] = None,
     ):
@@ -89,6 +103,8 @@
         self.locator = ComponentLocator(excludes=excludes)
         self.ref_resolver = ReferenceResolver()
+        if config is None:
+            config = {self.meta_key: {}}
         self.set(config=config)

     def __repr__(self):
@@ -102,7 +118,7 @@ def __getitem__(self, id: Union[str, int]):
             id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
                 go one level further into the nested structures.
                 Use digits indexing from "0" for list or other strings for dict.
-                For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
+                For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.

         """
         if id == "":
@@ -124,7 +140,7 @@ def __setitem__(self, id: Union[str, int], config: Any):
             id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
                 go one level further into the nested structures.
                 Use digits indexing from "0" for list or other strings for dict.
-                For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
+                For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.
             config: config to set at location ``id``.

         """
@@ -162,6 +178,100 @@ def set(self, config: Any, id: str = ""):
         """
         self[id] = config

+    def parse(self, reset: bool = True):
+        """
+        Recursively resolve `self.config` to replace the macro tokens with target content.
+        Then recursively parse the config source, add every item as ``ConfigItem`` to the reference resolver.
+
+        Args:
+            reset: whether to reset the ``reference_resolver`` before parsing. Defaults to `True`.
+
+        """
+        if reset:
+            self.ref_resolver.reset()
+        self.resolve_macro()
+        self._do_parse(config=self.get())
+
+    def get_parsed_content(self, id: str = "", **kwargs):
+        """
+        Get the parsed result of ``ConfigItem`` with the specified ``id``.
+
+        - If the item is ``ConfigComponent`` and ``instantiate=True``, the result is the instance.
+        - If the item is ``ConfigExpression`` and ``eval_expr=True``, the result is the evaluated output.
+        - Else, the result is the configuration content of `ConfigItem`.
+
+        Args:
+            id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
+                go one level further into the nested structures.
+                Use digits indexing from "0" for list or other strings for dict.
+                For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.
+            kwargs: additional keyword arguments to be passed to ``_resolve_one_item``.
+                Currently support ``reset`` (for parse), ``instantiate`` and ``eval_expr``. All defaulting to True.
+
+        """
+        if not self.ref_resolver.is_resolved():
+            # not parsed the config source yet, parse it
+            self.parse(kwargs.get("reset", True))
+        return self.ref_resolver.get_resolved_content(id=id, **kwargs)
+
+    def read_meta(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs):
+        """
+        Read the metadata from the specified JSON or YAML file.
+        The metadata as a dictionary will be stored at ``self.config["_meta_"]``.
+
+        Args:
+            f: filepath of the metadata file, the content must be a dictionary,
+                if providing a list of files, will merge the content of them.
+                if providing a dictionary directly, use it as metadata.
+            kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format.
+
+        """
+        self.set(self.load_config_files(f, **kwargs), self.meta_key)
+
+    def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs):
+        """
+        Read the config from the specified JSON or YAML file.
+        The config content will be stored in the `self.config` dictionary.
+
+        Args:
+            f: filepath of the config file, the content must be a dictionary,
+                if providing a list of files, will merge the content of them.
+                if providing a dictionary directly, use it as config.
+            kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format.
+
+        """
+        content = {self.meta_key: self.get(self.meta_key, {})}
+        content.update(self.load_config_files(f, **kwargs))
+        self.set(config=content)
+
+    def _do_resolve(self, config: Any):
+        """
+        Recursively resolve the config content to replace the macro tokens with target content.
+        The macro tokens start with "%", can be from another structured file, like:
+        ``{"net": "%default_net"}``, ``{"net": "%/data/config.json#net"}``.
+
+        Args:
+            config: input config content to resolve.
+
+        """
+        if isinstance(config, (dict, list)):
+            for k, v in enumerate(config) if isinstance(config, list) else config.items():
+                config[k] = self._do_resolve(v)
+        if isinstance(config, str) and config.startswith(MACRO_KEY):
+            path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :])
+            parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path))
+            return self._do_resolve(config=deepcopy(parser[ids]))
+        return config
+
+    def resolve_macro(self):
+        """
+        Recursively resolve `self.config` to replace the macro tokens with target content.
+        The macro tokens are marked as starting with "%", can be from another structured file, like:
+        ``"%default_net"``, ``"%/data/config.json#net"``.
+
+        """
+        self.set(self._do_resolve(config=deepcopy(self.get())))
+
     def _do_parse(self, config, id: str = ""):
         """
         Recursively parse the nested data in config source, add every item as `ConfigItem` to the resolver.
@@ -171,7 +281,7 @@ def _do_parse(self, config, id: str = ""):
             id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
                 go one level further into the nested structures.
                 Use digits indexing from "0" for list or other strings for dict.
-                For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
+                For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.

         """
         if isinstance(config, (dict, list)):
@@ -189,36 +299,77 @@ def _do_parse(self, config, id: str = ""):
         else:
             self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id))

-    def parse(self, reset: bool = True):
+    @classmethod
+    def load_config_file(cls, filepath: PathLike, **kwargs):
         """
-        Recursively parse the config source, add every item as ``ConfigItem`` to the resolver.
+        Load config file with the specified file path (currently support JSON and YAML files).

         Args:
-            reset: whether to reset the ``reference_resolver`` before parsing. Defaults to `True`.
+            filepath: path of target file to load, supported postfixes: `.json`, `.yml`, `.yaml`.
+            kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format.

         """
-        if reset:
-            self.ref_resolver.reset()
-        self._do_parse(config=self.config)
+        _filepath: str = str(Path(filepath))
+        if not re.compile(cls.path_match, re.IGNORECASE).findall(_filepath):
+            raise ValueError(f'unknown file input: "{filepath}"')
+        with open(_filepath) as f:
+            if _filepath.lower().endswith(cls.suffixes[0]):
+                return json.load(f, **kwargs)
+            if _filepath.lower().endswith(cls.suffixes[1:]):
+                return yaml.safe_load(f, **kwargs)
+            raise ValueError(f"only support JSON or YAML config file so far, got name {_filepath}.")
+
+    @classmethod
+    def load_config_files(cls, files: Union[PathLike, Sequence[PathLike], dict], **kwargs) -> dict:
+        """
+        Load config files into a single config dict.

-    def get_parsed_content(self, id: str = "", **kwargs):
+        Args:
+            files: path of target files to load, supported postfixes: `.json`, `.yml`, `.yaml`.
+            kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format.
         """
-        Get the parsed result of ``ConfigItem`` with the specified ``id``.
+        if isinstance(files, dict):  # already a config dict
+            return files
+        content = {}
+        for i in ensure_tuple(files):
+            content.update(cls.load_config_file(i, **kwargs))
+        return content
+
+    @classmethod
+    def export_config_file(cls, config: Dict, filepath: PathLike, fmt="json", **kwargs):
+        """
+        Export the config content to the specified file path (currently support JSON and YAML files).

-        - If the item is ``ConfigComponent`` and ``instantiate=True``, the result is the instance.
-        - If the item is ``ConfigExpression`` and ``eval_expr=True``, the result is the evaluated output.
-        - Else, the result is the configuration content of `ConfigItem`.
+        Args:
+            config: source config content to export.
+            filepath: target file path to save.
+            fmt: format of config content, currently support ``"json"`` and ``"yaml"``.
+            kwargs: other arguments for ``json.dump`` or ``yaml.safe_dump``, depends on the file format.
+
+        """
+        _filepath: str = str(Path(filepath))
+        writer = look_up_option(fmt.lower(), {"json", "yaml"})
+        with open(_filepath, "w") as f:
+            if writer == "json":
+                return json.dump(config, f, **kwargs)
+            if writer == "yaml":
+                return yaml.safe_dump(config, f, **kwargs)
+            raise ValueError(f"only support JSON or YAML config file so far, got {writer}.")
+
+    @classmethod
+    def split_path_id(cls, src: str) -> Tuple[str, str]:
+        """
+        Split `src` string into two parts: a config file path and component id.
+        The file path should end with `(json|yaml|yml)`. The component id should be separated by `#` if it exists.
+        If no path or no id, return "".

         Args:
-            id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
-                go one level further into the nested structures.
-                Use digits indexing from "0" for list or other strings for dict.
-                For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
-            kwargs: additional keyword arguments to be passed to ``_resolve_one_item``.
-                Currently support ``reset`` (for parse), ``instantiate`` and ``eval_expr``. All defaulting to True.
+            src: source string to split.

         """
-        if not self.ref_resolver.is_resolved():
-            # not parsed the config source yet, parse it
-            self.parse(kwargs.get("reset", True))
-        return self.ref_resolver.get_resolved_content(id=id, **kwargs)
+        result = re.compile(rf"({cls.suffix_match}(?=(?:{ID_SEP_KEY}.*)|$))", re.IGNORECASE).findall(src)
+        if not result:
+            return "", src  # the src is a pure id
+        path_name = result[0][0]  # at most one path_name
+        _, ids = src.rsplit(path_name, 1)
+        return path_name, ids[len(ID_SEP_KEY) :] if ids.startswith(ID_SEP_KEY) else ""
"@id#key", "@id#key#0", "@_target_#key" + id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*") def __init__(self, items: Optional[Sequence[ConfigItem]] = None): # save the items in a dictionary with the `ConfigItem.id` as key @@ -257,6 +260,9 @@ def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None): sub_id = f"{id}{cls.sep}{idx}" if id != "" else f"{idx}" if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): updated = refs_[sub_id] + if ConfigComponent.is_instantiable(v) and updated is None: + # the component is disabled + continue else: updated = cls.update_config_with_refs(v, sub_id, refs_) ret.update({idx: updated}) if isinstance(ret, dict) else ret.append(updated) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py new file mode 100644 index 0000000000..ebfd3e54ac --- /dev/null +++ b/monai/bundle/scripts.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pprint +from typing import Dict, Optional, Sequence, Union + +from monai.bundle.config_parser import ConfigParser + + +def _update_default_args(args: Optional[Union[str, Dict]] = None, **kwargs) -> Dict: + """ + Update the `args` with the input `kwargs`. + For dict data, recursively update the content based on the keys. + + Args: + args: source args to update. + kwargs: destination args to update. + + """ + args_: Dict = args if isinstance(args, dict) else {} # type: ignore + if isinstance(args, str): + # args are defined in a structured file + args_ = ConfigParser.load_config_file(args) + + # recursively update the default args with new args + for k, v in kwargs.items(): + args_[k] = _update_default_args(args_[k], **v) if isinstance(v, dict) and isinstance(args_.get(k), dict) else v + return args_ + + +def run( + meta_file: Optional[Union[str, Sequence[str]]] = None, + config_file: Optional[Union[str, Sequence[str]]] = None, + target_id: Optional[str] = None, + args_file: Optional[str] = None, + **override, +): + """ + Specify `meta_file` and `config_file` to run monai bundle components and workflows. + + Typical usage examples: + + .. code-block:: bash + + # Execute this module as a CLI entry: + python -m monai.bundle run --meta_file --config_file --target_id trainer + + # Override config values at runtime by specifying the component id and its new value: + python -m monai.bundle run --net#input_chns 1 ... + + # Override config values with another config file `/path/to/another.json`: + python -m monai.bundle run --net %/path/to/another.json ... + + # Override config values with part content of another config file: + python -m monai.bundle run --net %/data/other.json#net_arg ... + + # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. + # Other args still can override the default args at runtime: + python -m monai.bundle run --args_file "/workspace/data/args.json" --config_file + + Args: + meta_file: filepath of the metadata file, if `None`, must be provided in `args_file`. 
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
new file mode 100644
index 0000000000..ebfd3e54ac
--- /dev/null
+++ b/monai/bundle/scripts.py
@@ -0,0 +1,112 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pprint
+from typing import Dict, Optional, Sequence, Union
+
+from monai.bundle.config_parser import ConfigParser
+
+
+def _update_default_args(args: Optional[Union[str, Dict]] = None, **kwargs) -> Dict:
+    """
+    Update the `args` with the input `kwargs`.
+    For dict data, recursively update the content based on the keys.
+
+    Args:
+        args: source args to update.
+        kwargs: destination args to update.
+
+    """
+    args_: Dict = args if isinstance(args, dict) else {}  # type: ignore
+    if isinstance(args, str):
+        # args are defined in a structured file
+        args_ = ConfigParser.load_config_file(args)
+
+    # recursively update the default args with new args
+    for k, v in kwargs.items():
+        args_[k] = _update_default_args(args_[k], **v) if isinstance(v, dict) and isinstance(args_.get(k), dict) else v
+    return args_
+
+
+def run(
+    meta_file: Optional[Union[str, Sequence[str]]] = None,
+    config_file: Optional[Union[str, Sequence[str]]] = None,
+    target_id: Optional[str] = None,
+    args_file: Optional[str] = None,
+    **override,
+):
+    """
+    Specify `meta_file` and `config_file` to run monai bundle components and workflows.
+
+    Typical usage examples:
+
+    .. code-block:: bash
+
+        # Execute this module as a CLI entry:
+        python -m monai.bundle run --meta_file <meta path> --config_file <config path> --target_id trainer
+
+        # Override config values at runtime by specifying the component id and its new value:
+        python -m monai.bundle run --net#input_chns 1 ...
+
+        # Override config values with another config file `/path/to/another.json`:
+        python -m monai.bundle run --net %/path/to/another.json ...
+
+        # Override config values with part content of another config file:
+        python -m monai.bundle run --net %/data/other.json#net_arg ...
+
+        # Set default args of `run` in a JSON / YAML file, to help record and simplify the command line.
+        # Other args can still override the default args at runtime:
+        python -m monai.bundle run --args_file "/workspace/data/args.json" --config_file <config path>
+
+    Args:
+        meta_file: filepath of the metadata file, if `None`, must be provided in `args_file`.
+            if it is a list of file paths, the content of them will be merged.
+        config_file: filepath of the config file, if `None`, must be provided in `args_file`.
+            if it is a list of file paths, the content of them will be merged.
+        target_id: ID name of the target component or workflow, it must have a `run` method.
+        args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
+            `target_id` and override pairs, so that the command line inputs can be simplified.
+        override: id-value pairs to override or add the corresponding config content.
+            e.g. ``--net#input_chns 42``.
+
+    """
+    k_v = zip(["meta_file", "config_file", "target_id"], [meta_file, config_file, target_id])
+    for k, v in k_v:
+        if v is not None:
+            override[k] = v
+
+    full_kv = zip(
+        ("meta_file", "config_file", "target_id", "args_file", "override"),
+        (meta_file, config_file, target_id, args_file, override),
+    )
+    print("\n--- input summary of monai.bundle.scripts.run ---")
+    for name, val in full_kv:
+        print(f"> {name}: {pprint.pformat(val)}")
+    print("---\n\n")
+
+    _args = _update_default_args(args=args_file, **override)
+    for k in ("meta_file", "config_file"):
+        if k not in _args:
+            raise ValueError(f"{k} is required for 'monai.bundle run'.\n{run.__doc__}")
+
+    parser = ConfigParser()
+    parser.read_config(f=_args.pop("config_file"))
+    parser.read_meta(f=_args.pop("meta_file"))
+    id = _args.pop("target_id", "")
+
+    # the rest key-values in the args are to override config content
+    for k, v in _args.items():
+        parser[k] = v
+
+    workflow = parser.get_parsed_content(id=id)
+    if not hasattr(workflow, "run"):
+        raise ValueError(f"The parsed workflow {type(workflow)} does not have a `run` method.\n{run.__doc__}")
+    workflow.run()
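Besides the CLI entry shown in the docstring, `run` is a plain function and can be invoked from Python; a hypothetical sketch (the file names and the `evaluator` id mirror the test below, they are not a fixed contract):

```python
from monai.bundle import run

# equivalent to:
#   python -m monai.bundle run --meta_file meta.yaml --config_file inference.json \
#       --target_id evaluator --dataset#_target_ Dataset
run(
    meta_file="meta.yaml",  # hypothetical paths
    config_file="inference.json",
    target_id="evaluator",
    **{"dataset#_target_": "Dataset"},  # "#"-separated override of a nested config id
)
```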
+ +__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY"] + + +ID_REF_KEY = "@" # start of a reference to a ConfigItem +ID_SEP_KEY = "#" # separator for the ID of a ConfigItem +EXPR_KEY = "$" # start of a ConfigExpression +MACRO_KEY = "%" # start of a macro of a config diff --git a/monai/data/samplers.py b/monai/data/samplers.py index f5175266d8..40eed03187 100644 --- a/monai/data/samplers.py +++ b/monai/data/samplers.py @@ -50,7 +50,7 @@ def __init__( super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, **kwargs) if not even_divisible: - data_len = len(dataset) + data_len = len(dataset) # type: ignore extra_size = self.total_size - data_len if self.rank + extra_size >= self.num_replicas: self.num_samples -= 1 diff --git a/monai/utils/module.py b/monai/utils/module.py index de2152d182..065cc8f7c8 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -212,7 +212,7 @@ def instantiate(path: str, **kwargs): component = locate(path) if component is None: - raise ModuleNotFoundError(f"Cannot locate '{path}'.") + raise ModuleNotFoundError(f"Cannot locate class or function path: '{path}'.") if isclass(component): return component(**kwargs) # support regular function, static method and class method diff --git a/requirements-dev.txt b/requirements-dev.txt index eaf363fbe4..2b3786d1f3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -43,3 +43,6 @@ transformers mlflow matplotlib!=3.5.0 tensorboardX +types-PyYAML +pyyaml +fire diff --git a/setup.cfg b/setup.cfg index a9cfa09ccc..aa5eae07a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,6 +50,8 @@ all = mlflow matplotlib tensorboardX + pyyaml + fire nibabel = nibabel skimage = @@ -92,6 +94,10 @@ matplotlib = matplotlib tensorboardX = tensorboardX +pyyaml = + pyyaml +fire = + fire [flake8] select = B,C,E,F,N,P,T4,W,B9 @@ -106,7 +112,7 @@ ignore = W504 C408 N812 # lowercase 'torch.nn.functional' imported as non lowercase 'F' -per_file_ignores = __init__.py: F401 +per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py [isort] diff --git a/tests/min_tests.py b/tests/min_tests.py index e0710a93ec..8f01ee1826 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -159,6 +159,7 @@ def run_testsuit(): "test_zoomd", "test_prepare_batch_default_dist", "test_parallel_execution_dist", + "test_bundle_run", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_bundle_run.py b/tests/test_bundle_run.py new file mode 100644 index 0000000000..75002d3631 --- /dev/null +++ b/tests/test_bundle_run.py @@ -0,0 +1,84 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/monai/data/samplers.py b/monai/data/samplers.py
index f5175266d8..40eed03187 100644
--- a/monai/data/samplers.py
+++ b/monai/data/samplers.py
@@ -50,7 +50,7 @@ def __init__(
         super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, **kwargs)

         if not even_divisible:
-            data_len = len(dataset)
+            data_len = len(dataset)  # type: ignore
             extra_size = self.total_size - data_len
             if self.rank + extra_size >= self.num_replicas:
                 self.num_samples -= 1
diff --git a/monai/utils/module.py b/monai/utils/module.py
index de2152d182..065cc8f7c8 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -212,7 +212,7 @@ def instantiate(path: str, **kwargs):

     component = locate(path)
     if component is None:
-        raise ModuleNotFoundError(f"Cannot locate '{path}'.")
+        raise ModuleNotFoundError(f"Cannot locate class or function path: '{path}'.")
     if isclass(component):
         return component(**kwargs)
     # support regular function, static method and class method
diff --git a/requirements-dev.txt b/requirements-dev.txt
index eaf363fbe4..2b3786d1f3 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -43,3 +43,6 @@ transformers
 mlflow
 matplotlib!=3.5.0
 tensorboardX
+types-PyYAML
+pyyaml
+fire
diff --git a/setup.cfg b/setup.cfg
index a9cfa09ccc..aa5eae07a9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -50,6 +50,8 @@ all =
     mlflow
     matplotlib
     tensorboardX
+    pyyaml
+    fire
 nibabel =
     nibabel
 skimage =
@@ -92,6 +94,10 @@ matplotlib =
     matplotlib
 tensorboardX =
     tensorboardX
+pyyaml =
+    pyyaml
+fire =
+    fire

 [flake8]
 select = B,C,E,F,N,P,T4,W,B9
@@ -106,7 +112,7 @@ ignore =
     W504
     C408
     N812  # lowercase 'torch.nn.functional' imported as non lowercase 'F'
-per_file_ignores = __init__.py: F401
+per_file_ignores = __init__.py: F401, __main__.py: F401
 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py

 [isort]
diff --git a/tests/min_tests.py b/tests/min_tests.py
index e0710a93ec..8f01ee1826 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -159,6 +159,7 @@ def run_testsuit():
         "test_zoomd",
         "test_prepare_batch_default_dist",
         "test_parallel_execution_dist",
+        "test_bundle_run",
     ]

     assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_bundle_run.py b/tests/test_bundle_run.py
new file mode 100644
index 0000000000..75002d3631
--- /dev/null
+++ b/tests/test_bundle_run.py
@@ -0,0 +1,84 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+import nibabel as nib
+import numpy as np
+from parameterized import parameterized
+
+from monai.bundle import ConfigParser
+from monai.transforms import LoadImage
+
+TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), (128, 128, 128)]
+
+TEST_CASE_2 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.yaml"), (128, 128, 128)]
+
+
+class TestBundleRun(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_shape(self, config_file, expected_shape):
+        test_image = np.random.rand(*expected_shape)
+        with tempfile.TemporaryDirectory() as tempdir:
+            filename = os.path.join(tempdir, "image.nii")
+            nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)
+
+            # generate default args in a JSON file
+            def_args = {"config_file": "will be replaced by `config_file` arg"}
+            def_args_file = os.path.join(tempdir, "def_args.json")
+            ConfigParser.export_config_file(config=def_args, filepath=def_args_file)
+
+            meta = {"datalist": [{"image": filename}], "output_dir": tempdir, "window": (96, 96, 96)}
+            # test YAML file
+            meta_file = os.path.join(tempdir, "meta.yaml")
+            ConfigParser.export_config_file(config=meta, filepath=meta_file, fmt="yaml")
+
+            # test override with file, upper-case postfix
+            overridefile1 = os.path.join(tempdir, "override1.JSON")
+            with open(overridefile1, "w") as f:
+                # test override with part of the overriding file
+                json.dump({"move_net": "$@network_def.to(@device)"}, f)
+            os.makedirs(os.path.join(tempdir, "jsons"), exist_ok=True)
+            overridefile2 = os.path.join(tempdir, "jsons/override2.JSON")
+            with open(overridefile2, "w") as f:
+                # test override with the whole overriding file
+                json.dump("Dataset", f)
+
+            saver = LoadImage(image_only=True)
+
+            if sys.platform == "win32":
+                override = "--network $@network_def.to(@device) --dataset#_target_ Dataset"
+            else:
+                override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}"
+            # test with `monai.bundle` as CLI entry directly
+            cmd = "-m monai.bundle run --target_id evaluator"
+            cmd += f" --postprocessing#transforms#2#output_postfix seg {override}"
+            la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
+            ret = subprocess.check_call(la + ["--args_file", def_args_file])
+            self.assertEqual(ret, 0)
+            self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)
+
+            # here test the script with `google fire` tool as CLI
+            cmd = "-m fire monai.bundle.scripts run --target_id evaluator"
+            cmd += f" --evaluator#amp False {override}"
+            la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
+            ret = subprocess.check_call(la)
+            self.assertEqual(ret, 0)
+            self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
[{"": "LoadImaged", "": "true", "": {"keys": ["image"]}}, dict] +TEST_CASE_2 = [{"_target_": "LoadImaged", "keys": ["image"]}, LoadImaged] +# test full module path +TEST_CASE_3 = [{"_target_": "monai.transforms.LoadImaged", "keys": ["image"]}, LoadImaged] +# test `_disabled_` +TEST_CASE_4 = [{"_target_": "LoadImaged", "_disabled_": True, "keys": ["image"]}, dict] +# test `_disabled_` with string +TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict] # test non-monai modules and excludes -TEST_CASE_6 = [ - {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, -] -TEST_CASE_7 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] +TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam] +TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True}, partial] # test args contains "name" field TEST_CASE_8 = [ - {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, + {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, RandTorchVisiond, ] # test execute some function in args, test pre-imported global packages `monai` @@ -67,8 +64,8 @@ def test_component(self, test_input, output_type): locator = ComponentLocator(excludes=["metrics"]) configer = ConfigComponent(id="test", config=test_input, locator=locator) ret = configer.instantiate() - if test_input.get("", False): - # test `` works fine + if test_input.get("_disabled_", False): + # test `_disabled_` works fine self.assertEqual(ret, None) return self.assertTrue(isinstance(ret, output_type)) @@ -83,11 +80,11 @@ def test_expression(self, id, test_input): self.assertTrue(isinstance(ret, Callable)) def test_lazy_instantiation(self): - config = {"": "DataLoader", "": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}} + config = {"_target_": "DataLoader", "dataset": Dataset(data=[1, 2]), "batch_size": 2} configer = ConfigComponent(config=config, locator=None) init_config = configer.get_config() # modify config content at runtime - init_config[""]["batch_size"] = 4 + init_config["batch_size"] = 4 configer.update_config(config=init_config) ret = configer.instantiate() diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 5b5aa2b816..ce98be1214 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -25,24 +25,21 @@ TEST_CASE_1 = [ { "transform": { - "": "Compose", - "": { - "transforms": [ - {"": "LoadImaged", "": {"keys": "image"}}, - { - "": "RandTorchVisiond", - "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, - }, - ] - }, + "_target_": "Compose", + "transforms": [ + {"_target_": "LoadImaged", "keys": "image"}, + {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, + ], }, - "dataset": {"": "Dataset", "": {"data": [1, 2], "transform": "@transform"}}, + "dataset": {"_target_": "Dataset", "data": [1, 2], "transform": "@transform"}, "dataloader": { - "": "DataLoader", - "": {"dataset": "@dataset", "batch_size": 2, "collate_fn": "monai.data.list_data_collate"}, + "_target_": "DataLoader", + "dataset": "@dataset", + "batch_size": 2, + "collate_fn": "monai.data.list_data_collate", }, }, - ["transform", "transform##transforms#0", "transform##transforms#1", "dataset", "dataloader"], + ["transform", "transform#transforms#0", "transform#transforms#1", "dataset", "dataloader"], [Compose, LoadImaged, 
RandTorchVisiond, Dataset, DataLoader], ] @@ -67,9 +64,9 @@ def __call__(self, a, b): "cls_func": "$TestClass.cls_compute", "lambda_static_func": "$lambda x, y: TestClass.compute(x, y)", "lambda_cls_func": "$lambda x, y: TestClass.cls_compute(x, y)", - "compute": {"": "tests.test_config_parser.TestClass.compute", "": {"func": "@basic_func"}}, - "cls_compute": {"": "tests.test_config_parser.TestClass.cls_compute", "": {"func": "@basic_func"}}, - "call_compute": {"": "tests.test_config_parser.TestClass"}, + "compute": {"_target_": "tests.test_config_parser.TestClass.compute", "func": "@basic_func"}, + "cls_compute": {"_target_": "tests.test_config_parser.TestClass.cls_compute", "func": "@basic_func"}, + "call_compute": {"_target_": "tests.test_config_parser.TestClass"}, "error_func": "$TestClass.__call__", "": "$lambda x, y: x + y", } @@ -78,17 +75,17 @@ def __call__(self, a, b): class TestConfigComponent(unittest.TestCase): def test_config_content(self): - test_config = {"preprocessing": [{"": "LoadImage"}], "dataset": {"": "Dataset"}} + test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) # test `get`, `set`, `__getitem__`, `__setitem__` self.assertEqual(str(parser.get()), str(test_config)) parser.set(config=test_config) self.assertListEqual(parser["preprocessing"], test_config["preprocessing"]) - parser["dataset"] = {"": "CacheDataset"} - self.assertEqual(parser["dataset"][""], "CacheDataset") + parser["dataset"] = {"_target_": "CacheDataset"} + self.assertEqual(parser["dataset"]["_target_"], "CacheDataset") # test nested ids - parser["dataset#"] = "Dataset" - self.assertEqual(parser["dataset#"], "Dataset") + parser["dataset#_target_"] = "Dataset" + self.assertEqual(parser["dataset#_target_"], "Dataset") # test int id parser.set(["test1", "test2", "test3"]) parser[1] = "test4" @@ -99,11 +96,11 @@ def test_config_content(self): def test_parse(self, config, expected_ids, output_types): parser = ConfigParser(config=config, globals={"monai": "monai"}) # test lazy instantiation with original config content - parser["transform"][""]["transforms"][0][""]["keys"] = "label1" - self.assertEqual(parser.get_parsed_content(id="transform##transforms#0").keys[0], "label1") + parser["transform"]["transforms"][0]["keys"] = "label1" + self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1") # test nested id - parser["transform##transforms#0##keys"] = "label2" - self.assertEqual(parser.get_parsed_content(id="transform##transforms#0").keys[0], "label2") + parser["transform#transforms#0#keys"] = "label2" + self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label2") for id, cls in zip(expected_ids, output_types): self.assertTrue(isinstance(parser.get_parsed_content(id), cls)) # test root content diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py index e16a795c40..e6b01c05f4 100644 --- a/tests/test_reference_resolver.py +++ b/tests/test_reference_resolver.py @@ -27,11 +27,10 @@ TEST_CASE_1 = [ { # all the recursively parsed config items - "transform#1": {"": "LoadImaged", "": {"keys": ["image"]}}, - "transform#1#": "LoadImaged", - "transform#1#": {"keys": ["image"]}, - "transform#1##keys": ["image"], - "transform#1##keys#0": "image", + "transform#1": {"_target_": "LoadImaged", "keys": ["image"]}, + "transform#1#_target_": "LoadImaged", + "transform#1#keys": ["image"], + "transform#1#keys#0": "image", }, "transform#1", LoadImaged, @@ 
diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py
index e16a795c40..e6b01c05f4 100644
--- a/tests/test_reference_resolver.py
+++ b/tests/test_reference_resolver.py
@@ -27,11 +27,10 @@
 TEST_CASE_1 = [
     {
         # all the recursively parsed config items
-        "transform#1": {"<name>": "LoadImaged", "<args>": {"keys": ["image"]}},
-        "transform#1#<name>": "LoadImaged",
-        "transform#1#<args>": {"keys": ["image"]},
-        "transform#1#<args>#keys": ["image"],
-        "transform#1#<args>#keys#0": "image",
+        "transform#1": {"_target_": "LoadImaged", "keys": ["image"]},
+        "transform#1#_target_": "LoadImaged",
+        "transform#1#keys": ["image"],
+        "transform#1#keys#0": "image",
     },
     "transform#1",
     LoadImaged,
@@ -40,20 +39,15 @@
 TEST_CASE_2 = [
     {
         # some of the recursively parsed config items
-        "dataloader": {
-            "<name>": "DataLoader",
-            "<args>": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"},
-        },
-        "dataset": {"<name>": "Dataset", "<args>": {"data": [1, 2]}},
-        "dataloader#<name>": "DataLoader",
-        "dataloader#<args>": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"},
-        "dataloader#<args>#dataset": "@dataset",
-        "dataloader#<args>#collate_fn": "$monai.data.list_data_collate",
-        "dataset#<name>": "Dataset",
-        "dataset#<args>": {"data": [1, 2]},
-        "dataset#<args>#data": [1, 2],
-        "dataset#<args>#data#0": 1,
-        "dataset#<args>#data#1": 2,
+        "dataloader": {"_target_": "DataLoader", "dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"},
+        "dataset": {"_target_": "Dataset", "data": [1, 2]},
+        "dataloader#_target_": "DataLoader",
+        "dataloader#dataset": "@dataset",
+        "dataloader#collate_fn": "$monai.data.list_data_collate",
+        "dataset#_target_": "Dataset",
+        "dataset#data": [1, 2],
+        "dataset#data#0": 1,
+        "dataset#data#1": 2,
     },
     "dataloader",
     DataLoader,
@@ -62,15 +56,11 @@
 TEST_CASE_3 = [
     {
         # all the recursively parsed config items
-        "transform#1": {
-            "<name>": "RandTorchVisiond",
-            "<args>": {"keys": "image", "name": "ColorJitter", "brightness": 0.25},
-        },
-        "transform#1#<name>": "RandTorchVisiond",
-        "transform#1#<args>": {"keys": "image", "name": "ColorJitter", "brightness": 0.25},
-        "transform#1#<args>#keys": "image",
-        "transform#1#<args>#name": "ColorJitter",
-        "transform#1#<args>#brightness": 0.25,
+        "transform#1": {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
+        "transform#1#_target_": "RandTorchVisiond",
+        "transform#1#keys": "image",
+        "transform#1#name": "ColorJitter",
+        "transform#1#brightness": 0.25,
     },
     "transform#1",
     RandTorchVisiond,
@@ -97,7 +87,7 @@ def test_resolve(self, configs, expected_id, output_type):
         # test lazy instantiation
         item = resolver.get_item(expected_id, resolve=True)
         config = item.get_config()
-        config["<disabled>"] = False
+        config["_disabled_"] = False
         item.update_config(config=config)
         if isinstance(item, ConfigComponent):
             result = item.instantiate()
"_target_": "AsDiscreted", + "keys": "pred", + "argmax": true + }, + { + "_target_": "SaveImaged", + "keys": "pred", + "meta_keys": "image_meta_dict", + "output_dir": "@_meta_#output_dir" + } + ] + }, + "evaluator": { + "_target_": "SupervisedEvaluator", + "device": "@device", + "val_data_loader": "@dataloader", + "network": "@network", + "inferer": "@inferer", + "postprocessing": "@postprocessing", + "amp": false + } +} diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml new file mode 100644 index 0000000000..58eeca8191 --- /dev/null +++ b/tests/testing_data/inference.yaml @@ -0,0 +1,74 @@ +--- +device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')" +network_def: + _target_: UNet + spatial_dims: 3 + in_channels: 1 + out_channels: 2 + channels: + - 16 + - 32 + - 64 + - 128 + - 256 + strides: + - 2 + - 2 + - 2 + - 2 + num_res_units: 2 + norm: batch +network: need override +preprocessing: + _target_: Compose + transforms: + - _target_: LoadImaged + keys: image + - _target_: EnsureChannelFirstd + keys: image + - _target_: ScaleIntensityd + keys: image + - _target_: RandRotated + _disabled_: true + keys: image + - _target_: EnsureTyped + keys: image +dataset: + _target_: need override + data: "@_meta_#datalist" + transform: "@preprocessing" +dataloader: + _target_: DataLoader + dataset: "@dataset" + batch_size: 1 + shuffle: false + num_workers: 4 +inferer: + _target_: SlidingWindowInferer + roi_size: + - 96 + - 96 + - 96 + sw_batch_size: 4 + overlap: 0.5 +postprocessing: + _target_: Compose + transforms: + - _target_: Activationsd + keys: pred + softmax: true + - _target_: AsDiscreted + keys: pred + argmax: true + - _target_: SaveImaged + keys: pred + meta_keys: image_meta_dict + output_dir: "@_meta_#output_dir" +evaluator: + _target_: SupervisedEvaluator + device: "@device" + val_data_loader: "@dataloader" + network: "@network" + inferer: "@inferer" + postprocessing: "@postprocessing" + amp: false