From 6f13b8d0054efd522f6f30322b22fc77e0d8cec7 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:06:19 +0800 Subject: [PATCH] Avoid breaking change in creating `BundleWorkflow` (#6950) Fixes # . ### Description Avoid breaking changes introduced by https://github.com/Project-MONAI/MONAI/pull/6835 - when creating `BundleWorkflow` - when using `load` API, add `return_state_dict` when `model` and `net_name` are both `None`. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/bundle/scripts.py | 77 +++++++++++++++++++++++++---------- monai/bundle/workflows.py | 13 +++++- tests/ngc_bundle_download.py | 7 +++- tests/test_bundle_download.py | 31 ++++++++++++-- 4 files changed, 101 insertions(+), 27 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index cdea2b4218..fc8dafbc77 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -28,6 +28,7 @@ from monai.apps.mmars.mmars import _get_all_ngc_models from monai.apps.utils import _basename, download_url, extractall, get_logger +from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow @@ -247,7 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path: return Path(bundle_dir) -@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.4") +@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.5") def download( name: str | None = None, version: str | None = None, @@ -375,8 +376,9 @@ def download( ) -@deprecated_arg("net_name", since="1.3", removed="1.4", msg_suffix="please use ``model`` instead.") -@deprecated_arg("net_kwargs", since="1.3", removed="1.3", msg_suffix="please use ``model`` instead.") +@deprecated_arg("net_name", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.") +@deprecated_arg("net_kwargs", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.") +@deprecated_arg("return_state_dict", since="1.3", removed="1.5") def load( name: str, model: torch.nn.Module | None = None, @@ -395,8 +397,10 @@ def load( workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, copy_model_args: dict | None = None, + return_state_dict: bool = True, + net_override: dict | None = None, net_name: str | None = None, - **net_override: Any, + **net_kwargs: Any, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ Load model weights or TorchScript module of a bundle. @@ -441,7 +445,12 @@ def load( workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". args_file: a JSON or YAML file to provide default values for all the args in "download" function. copy_model_args: other arguments for the `monai.networks.copy_model_state` function. - net_override: id-value pairs to override the parameters in the network of the bundle. + return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network + from `_workflow.network_def` will be instantiated and load the achieved weights. + net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`. + net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. + This argument only works when loading weights. + net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. Returns: 1. If `load_ts_module` is `False` and `model` is `None`, @@ -452,9 +461,15 @@ def load( 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, the corresponding metadata dict, and extra files dict. please check `monai.data.load_net_with_metadata` for more details. + 4. If `return_state_dict` is True, return model weights, only used for compatibility + when `model` and `net_name` are all `None`. """ + if return_state_dict and (model is not None or net_name is not None): + warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.") + bundle_dir_ = _process_bundle_dir(bundle_dir) + net_override = {} if net_override is None else net_override copy_model_args = {} if copy_model_args is None else copy_model_args if device is None: @@ -466,7 +481,7 @@ def load( if remove_prefix: name = _remove_ngc_prefix(name, prefix=remove_prefix) full_path = os.path.join(bundle_dir_, name, model_file) - if not os.path.exists(full_path) or model is None: + if not os.path.exists(full_path): download( name=name, version=version, @@ -477,34 +492,52 @@ def load( progress=progress, args_file=args_file, ) - train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json" - if train_config_file.is_file(): - _net_override = {f"network_def#{key}": value for key, value in net_override.items()} - _workflow = create_workflow( - workflow_name=workflow_name, - args_file=args_file, - config_file=str(train_config_file), - workflow_type=workflow_type, - **_net_override, - ) - else: - _workflow = None # loading with `torch.jit.load` if load_ts_module is True: return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) # loading with `torch.load` model_dict = torch.load(full_path, map_location=torch.device(device)) + if not isinstance(model_dict, Mapping): warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") model_dict = get_state_dict(model_dict) - if model is None and _workflow is None: + if return_state_dict: return model_dict - model = _workflow.network_def if model is None else model - model.to(device) - copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args) + _workflow = None + if model is None and net_name is None: + bundle_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json" + if bundle_config_file.is_file(): + _net_override = {f"network_def#{key}": value for key, value in net_override.items()} + _workflow = create_workflow( + workflow_name=workflow_name, + args_file=args_file, + config_file=str(bundle_config_file), + workflow_type=workflow_type, + **_net_override, + ) + else: + warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.") + return model_dict + if _workflow is not None: + if not hasattr(_workflow, "network_def"): + warnings.warn("No available network definition in the bundle, return state dict instead.") + return model_dict + else: + model = _workflow.network_def + elif net_name is not None: + net_kwargs["_target_"] = net_name + configer = ConfigComponent(config=net_kwargs) + model = configer.instantiate() # type: ignore + + model.to(device) # type: ignore + + copy_model_state( + dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args # type: ignore + ) + return model diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 8d53f2e88c..5f34578f7b 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -43,6 +43,10 @@ class BundleWorkflow(ABC): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. """ @@ -56,7 +60,8 @@ class BundleWorkflow(ABC): new_name="workflow_type", msg_suffix="please use `workflow_type` instead.", ) - def __init__(self, workflow_type: str | None = None): + def __init__(self, workflow_type: str | None = None, workflow: str | None = None): + workflow_type = workflow if workflow is not None else workflow_type if workflow_type is None: self.properties = copy(MetaProperties) self.workflow_type = None @@ -198,6 +203,10 @@ class ConfigWorkflow(BundleWorkflow): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` @@ -221,8 +230,10 @@ def __init__( final_id: str = "finalize", tracking: str | dict | None = None, workflow_type: str | None = None, + workflow: str | None = None, **override: Any, ) -> None: + workflow_type = workflow if workflow is not None else workflow_type super().__init__(workflow_type=workflow_type) if config_file is not None: _config_files = ensure_tuple(config_file) diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index f699914f6a..ba35f2b80c 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -83,7 +83,12 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) model = load( - name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix + name=bundle_name, + source="ngc", + version=version, + bundle_dir=tempdir, + remove_prefix=remove_prefix, + return_state_dict=False, ) assert_allclose( model.state_dict()[TESTCASE_WEIGHTS["key"]], diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 2457af3229..d43cf3b9c0 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -146,6 +146,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) source="github", progress=False, device=device, + return_state_dict=True, ) # prepare network @@ -174,13 +175,29 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) bundle_dir=tempdir, progress=False, device=device, - net_name=model_name, source="github", + return_state_dict=False, ) model_2.eval() output_2 = model_2.forward(input_tensor) assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + # test compatibility with return_state_dict=True. + model_3 = load( + name=bundle_name, + model_file=model_file, + bundle_dir=tempdir, + progress=False, + device=device, + net_name=model_name, + source="github", + return_state_dict=False, + **net_args, + ) + model_3.eval() + output_3 = model_3.forward(input_tensor) + assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + @parameterized.expand([TEST_CASE_7]) @skip_if_quick def test_load_weights_with_net_override(self, bundle_name, device, net_override): @@ -188,7 +205,14 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: # load weights - model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device) + model = load( + name=bundle_name, + bundle_dir=tempdir, + source="monaihosting", + progress=False, + device=device, + return_state_dict=False, + ) # prepare data and test input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) @@ -209,7 +233,8 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) source="monaihosting", progress=False, device=device, - **net_override, + return_state_dict=False, + net_override=net_override, ) # prepare data and test