diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index be6c0caba6..8a3c721cfe 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -461,6 +461,10 @@ def load( device = "cuda:0" if is_available() else "cpu" if model_file is None: model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") + if source == "ngc": + name = _add_ngc_prefix(name) + if remove_prefix: + name = _remove_ngc_prefix(name, prefix=remove_prefix) full_path = os.path.join(bundle_dir_, name, model_file) if not os.path.exists(full_path) or model is None: download( diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index f380626d73..f699914f6a 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -82,11 +82,15 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download self.assertTrue(os.path.exists(full_file_path)) self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) - weights = load( + model = load( name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix ) assert_allclose( - weights[TESTCASE_WEIGHTS["key"]], TESTCASE_WEIGHTS["value"], atol=1e-4, rtol=1e-4, type_test=False + model.state_dict()[TESTCASE_WEIGHTS["key"]], + TESTCASE_WEIGHTS["value"], + atol=1e-4, + rtol=1e-4, + type_test=False, )