diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index dd556e9eb3..711f1d2875 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -19,6 +19,7 @@ ckpt_export, create_workflow, download, + download_large_files, get_all_bundles_list, get_bundle_info, get_bundle_versions, diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index e2a78bac5e..778c9ef2f0 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -14,6 +14,7 @@ from monai.bundle.scripts import ( ckpt_export, download, + download_large_files, init_bundle, onnx_export, run, diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index fc8dafbc77..056a4ddef3 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1619,3 +1619,43 @@ def create_workflow( workflow_.initialize() return workflow_ + + +def download_large_files(bundle_path: str | None = None, large_file_name: str | None = None) -> None: + """ + This utility allows you to download large files from a bundle. It supports file suffixes like ".yml", ".yaml", and ".json". + If you don't specify a `large_file_name`, it will automatically search for large files among the supported suffixes. + + Typical usage examples: + .. code-block:: bash + + # Execute this module as a CLI entry to download large files from a bundle path: + python -m monai.bundle download_large_files --bundle_path + + # Execute this module as a CLI entry to download large files from the bundle path with a specified `large_file_name`: + python -m monai.bundle download_large_files --bundle_path --large_file_name large_files.yaml + + Args: + bundle_path: (Optional) The path to the bundle where the files are located. Default is `os.getcwd()`. + large_file_name: (Optional) The name of the large file to be downloaded. + + """ + bundle_path = os.getcwd() if bundle_path is None else bundle_path + if large_file_name is None: + large_file_path = list(Path(bundle_path).glob("large_files*")) + large_file_path = list(filter(lambda x: x.suffix in [".yml", ".yaml", ".json"], large_file_path)) + if len(large_file_path) == 0: + raise FileNotFoundError(f"Cannot find the large_files.yml/yaml/json under {bundle_path}.") + + parser = ConfigParser() + parser.read_config(large_file_path) + large_files_list = parser.get()["large_files"] + for lf_data in large_files_list: + lf_data["fuzzy"] = True + if "hash_val" in lf_data and lf_data.get("hash_val", "") == "": + lf_data.pop("hash_val") + if "hash_type" in lf_data and lf_data.get("hash_type", "") == "": + lf_data.pop("hash_type") + lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) + lf_data.pop("path") + download_url(**lf_data) diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index d43cf3b9c0..3c78112bfa 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -71,6 +71,13 @@ {"spatial_dims": 3, "out_channels": 5}, ] +TEST_CASE_8 = [ + ["network.json", "test_output.pt", "test_input.pt", "large_files.yaml"], + "test_bundle", + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle_v0.1.2.zip", + {"model.pt": "27952767e2e154e3b0ee65defc5aed38", "model.ts": "97746870fe591f69ac09827175b00675"}, +] + class TestDownload(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @@ -148,7 +155,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) device=device, return_state_dict=True, ) - # prepare network with open(os.path.join(tempdir, bundle_name, bundle_files[2])) as f: net_args = json.load(f)["network_def"] @@ -275,5 +281,33 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, self.assertTrue("network.json" in extra_file_dict.keys()) +class TestDownloadLargefiles(unittest.TestCase): + @parameterized.expand([TEST_CASE_8]) + @skip_if_quick + def test_url_download_large_files(self, bundle_files, bundle_name, url, hash_val): + with skip_if_downloading_fails(): + # download a single file from url, also use `args_file` + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"name": bundle_name, "bundle_dir": tempdir, "url": ""} + def_args_file = os.path.join(tempdir, "def_args.json") + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] + cmd += ["--url", url, "--source", "github"] + command_line_tests(cmd) + for file in bundle_files: + file_path = os.path.join(tempdir, bundle_name, file) + print(file_path) + self.assertTrue(os.path.exists(file_path)) + + # download large files + bundle_path = os.path.join(tempdir, bundle_name) + cmd = ["coverage", "run", "-m", "monai.bundle", "download_large_files", "--bundle_path", bundle_path] + command_line_tests(cmd) + for file in ["model.pt", "model.ts"]: + file_path = os.path.join(tempdir, bundle_name, f"models/{file}") + self.assertTrue(check_hash(filepath=file_path, val=hash_val[file])) + + if __name__ == "__main__": unittest.main()