From bf9db2cada5e9cad8e9a1df7d070a2c3afac8626 Mon Sep 17 00:00:00 2001 From: "Ted.Lai" Date: Mon, 27 May 2024 20:17:14 +0800 Subject: [PATCH 01/10] Modify _log_input_summary function Fixes #7513 Signed-off-by: Ted.Lai --- monai/bundle/scripts.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 598d938cbd..4718344bab 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -10,6 +10,7 @@ # limitations under the License. from __future__ import annotations +from io import StringIO import ast import json @@ -116,10 +117,20 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: def _log_input_summary(tag: str, args: dict) -> None: - logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") +<<<<<<< HEAD + log_buffer = StringIO() + log_buffer.write(f"--- input summary of monai.bundle.scripts.{tag} ---\n") for name, val in args.items(): - logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}") - logger.info("---\n\n") + log_buffer.write(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}\n") + log_buffer.write("---\n\n") +======= + log_buffer = StringIO() + log_buffer.write(f"--- input summary of monai.bundle.scripts.{tag} ---\n") + for name, val in args.items(): + log_buffer.write(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}\n") + log_buffer.write("---\n\n") +>>>>>>> 28f5ebe7e3a0487b9d78123a80d565f1fa26615a + logger.info(log_buffer.getvalue()) # output all content in buffer def _get_var_names(expr: str) -> list[str]: From 7f2633ec78cd272374f387cb28f08479e57a7f3a Mon Sep 17 00:00:00 2001 From: "Ted.Lai" Date: Mon, 27 May 2024 20:29:28 +0800 Subject: [PATCH 02/10] Modify "_log_input_summary" function Signed-off-by: Ted.Lai --- monai/bundle/scripts.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4718344bab..fb86f96b30 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -117,20 +117,14 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: def _log_input_summary(tag: str, args: dict) -> None: -<<<<<<< HEAD - log_buffer = StringIO() + log_buffer = StringIO() + log_buffer.write(f"--- input summary of monai.bundle.scripts.{tag} ---\n") for name, val in args.items(): log_buffer.write(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}\n") log_buffer.write("---\n\n") -======= - log_buffer = StringIO() - log_buffer.write(f"--- input summary of monai.bundle.scripts.{tag} ---\n") - for name, val in args.items(): - log_buffer.write(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}\n") - log_buffer.write("---\n\n") ->>>>>>> 28f5ebe7e3a0487b9d78123a80d565f1fa26615a - logger.info(log_buffer.getvalue()) # output all content in buffer + + logger.info(log_buffer.getvalue()) def _get_var_names(expr: str) -> list[str]: From aa3d6ac320e5665208c28b4f370551930e55886a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 May 2024 12:44:40 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index fb86f96b30..ab9b205451 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -119,10 +119,10 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: def _log_input_summary(tag: str, args: dict) -> None: log_buffer = StringIO() - log_buffer.write(f"--- input summary of monai.bundle.scripts.{tag} ---\n") + log_buffer.write(f"--- input summary of monai.bundle.scripts.{tag} ---\n") for name, val in args.items(): log_buffer.write(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}\n") - log_buffer.write("---\n\n") + log_buffer.write("---\n\n") logger.info(log_buffer.getvalue()) From 60b9947c7e4080404940afd94219549c8d51b70e Mon Sep 17 00:00:00 2001 From: "Ted.Lai" Date: Thu, 12 Sep 2024 17:22:42 +0800 Subject: [PATCH 04/10] modify _log_input_summary function. Signed-off-by: Ted.Lai --- monai/bundle/scripts.py | 67 +++++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index ab9b205451..56c10c075c 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -10,7 +10,12 @@ # limitations under the License. from __future__ import annotations -from io import StringIO + +from datetime import datetime +import torch.distributed as dist +import threading +import logging.config +import configparser import ast import json @@ -115,17 +120,61 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: """ return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()]) +def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> None: + """ + Logs the input summary of a MONAI bundle script to the console and a local log file. + Each rank's logs are tagged with their rank number and saved to individual log files. + Reads the base log path from the logging.conf file and creates a separate log file for each rank. + Add log_all_ranks as a parameter to determine whether to log all ranks or only rank 0. + """ + is_distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_distributed else 0 + + if not log_all_ranks and rank != 0: + return + + log_lock = threading.Lock() + + with log_lock: + config_file = 'logging.conf' + config = configparser.ConfigParser() + config.read(config_file) + + if config.has_option('handler_fileHandler', 'args'): + base_log_dir = config.get('handler_fileHandler', 'args').strip("()").split(",")[0].strip().strip("'") + else: + base_log_dir = 'logs/' + + if not os.path.exists(base_log_dir): + os.makedirs(base_log_dir) + + log_file_path = os.path.join(base_log_dir, f'rank_{rank}_logfile.log') + logger = logging.getLogger(f"rank_{rank}_logger") + logger.setLevel(logging.INFO) + + if not logger.hasHandlers(): + file_handler = logging.FileHandler(log_file_path, mode='a') + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + logger = logging.LoggerAdapter(logger, {'rank': rank}) -def _log_input_summary(tag: str, args: dict) -> None: - log_buffer = StringIO() + formatted_args = {name: ", ".join(map(str, val)) if isinstance(val, (list, tuple, dict)) else val + for name, val in args.items()} - log_buffer.write(f"--- input summary of monai.bundle.scripts.{tag} ---\n") - for name, val in args.items(): - log_buffer.write(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}\n") - log_buffer.write("---\n\n") + logger.info(f"--- Input summary of monai.bundle.scripts.{tag} ---") - logger.info(log_buffer.getvalue()) + for name, formatted_val in formatted_args.items(): + logger.info(f"> {name}: {formatted_val}") + logger.info(f"---\n\n") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + logger.info(f"Log written at {timestamp}") def _get_var_names(expr: str) -> list[str]: """ @@ -1808,4 +1857,4 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str | lf_data.pop("hash_type") lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) lf_data.pop("path") - download_url(**lf_data) + download_url(**lf_data) \ No newline at end of file From 1fda340e7f0fc57be03509b1a80c134dfde03fe8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:40:05 +0000 Subject: [PATCH 05/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index e20663aa66..ec539d397b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -15,7 +15,7 @@ import torch.distributed as dist import threading import logging.config -import configparser +import configparser import ast import json @@ -57,7 +57,6 @@ get_equivalent_dtype, min_version, optional_import, - pprint_edges, ) validate, _ = optional_import("jsonschema", name="validate") @@ -132,11 +131,11 @@ def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> Non Reads the base log path from the logging.conf file and creates a separate log file for each rank. Add log_all_ranks as a parameter to determine whether to log all ranks or only rank 0. """ - is_distributed = dist.is_available() and dist.is_initialized() - rank = dist.get_rank() if is_distributed else 0 + is_distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_distributed else 0 if not log_all_ranks and rank != 0: - return + return log_lock = threading.Lock() @@ -177,7 +176,7 @@ def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> Non for name, formatted_val in formatted_args.items(): logger.info(f"> {name}: {formatted_val}") - logger.info(f"---\n\n") + logger.info("---\n\n") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") logger.info(f"Log written at {timestamp}") @@ -2063,4 +2062,4 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str | lf_data.pop("hash_type") lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) lf_data.pop("path") - download_url(**lf_data) \ No newline at end of file + download_url(**lf_data) From 93ad7a5cab45975ddb82e865c40f07a511c394d4 Mon Sep 17 00:00:00 2001 From: "Ted.Lai" Date: Thu, 12 Sep 2024 18:02:36 +0800 Subject: [PATCH 06/10] fix: auto-fix imports in scripts.py Signed-off-by: Ted.Lai --- monai/bundle/scripts.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index e20663aa66..aa0a87bb63 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -11,19 +11,17 @@ from __future__ import annotations -from datetime import datetime -import torch.distributed as dist -import threading -import logging.config -import configparser - import ast +import configparser import json +import logging.config import os import re +import threading import warnings import zipfile from collections.abc import Mapping, Sequence +from datetime import datetime from functools import partial from pathlib import Path from pydoc import locate @@ -32,6 +30,7 @@ from typing import Any, Callable import torch +import torch.distributed as dist from torch.cuda import is_available from monai._version import get_versions From c55e9c784e6053df82da50c869b6944c0162aa43 Mon Sep 17 00:00:00 2001 From: "Ted.Lai" Date: Thu, 12 Sep 2024 18:39:16 +0800 Subject: [PATCH 07/10] fix: auto-fix imports in scripts.py Signed-off-by: Ted.Lai --- monai/bundle/scripts.py | 366 +++++++++------------------------------- 1 file changed, 78 insertions(+), 288 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 78da592088..d5ce527989 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -11,15 +11,6 @@ from __future__ import annotations -<<<<<<< HEAD -======= -from datetime import datetime -import torch.distributed as dist -import threading -import logging.config -import configparser - ->>>>>>> 1fda340e7f0fc57be03509b1a80c134dfde03fe8 import ast import configparser import json @@ -28,10 +19,8 @@ import re import threading import warnings -import zipfile from collections.abc import Mapping, Sequence from datetime import datetime -from functools import partial from pathlib import Path from pydoc import locate from shutil import copyfile @@ -42,11 +31,11 @@ import torch.distributed as dist from torch.cuda import is_available -from monai._version import get_versions +from monai.apps.mmars.mmars import _get_all_ngc_models from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser -from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv +from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata @@ -65,6 +54,7 @@ get_equivalent_dtype, min_version, optional_import, + pprint_edges, ) validate, _ = optional_import("jsonschema", name="validate") @@ -81,9 +71,6 @@ DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting") PPRINT_CONFIG_N = 5 -MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" -NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit" - def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ @@ -118,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw if isinstance(v, dict) and isinstance(args_.get(k), dict): args_[k] = update_kwargs(args_[k], ignore_none, **v) else: - merge_kv(args_, k, v) + args_[k] = v return args_ @@ -132,6 +119,7 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: """ return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()]) + def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> None: """ Logs the input summary of a MONAI bundle script to the console and a local log file. @@ -148,45 +136,48 @@ def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> Non log_lock = threading.Lock() with log_lock: - config_file = 'logging.conf' - config = configparser.ConfigParser() - config.read(config_file) + config_file = "logging.conf" + config = configparser.ConfigParser() + config.read(config_file) + + if config.has_option("handler_fileHandler", "args"): + base_log_dir = config.get("handler_fileHandler", "args").strip("()").split(",")[0].strip().strip("'") + else: + base_log_dir = "logs/" - if config.has_option('handler_fileHandler', 'args'): - base_log_dir = config.get('handler_fileHandler', 'args').strip("()").split(",")[0].strip().strip("'") - else: - base_log_dir = 'logs/' + if not os.path.exists(base_log_dir): + os.makedirs(base_log_dir) - if not os.path.exists(base_log_dir): - os.makedirs(base_log_dir) + log_file_path = os.path.join(base_log_dir, f"rank_{rank}_logfile.log") + logger = logging.getLogger(f"rank_{rank}_logger") + logger.setLevel(logging.INFO) - log_file_path = os.path.join(base_log_dir, f'rank_{rank}_logfile.log') - logger = logging.getLogger(f"rank_{rank}_logger") - logger.setLevel(logging.INFO) + if not logger.hasHandlers(): + file_handler = logging.FileHandler(log_file_path, mode="a") + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) - if not logger.hasHandlers(): - file_handler = logging.FileHandler(log_file_path, mode='a') - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) + logger = logging.LoggerAdapter(logger, {"rank": rank}) - logger = logging.LoggerAdapter(logger, {'rank': rank}) + formatted_args = { + name: ", ".join(map(str, val)) if isinstance(val, (list, tuple, dict)) else val + for name, val in args.items() + } - formatted_args = {name: ", ".join(map(str, val)) if isinstance(val, (list, tuple, dict)) else val - for name, val in args.items()} + logger.info(f"--- Input summary of monai.bundle.scripts.{tag} ---") - logger.info(f"--- Input summary of monai.bundle.scripts.{tag} ---") + for name, formatted_val in formatted_args.items(): + logger.info(f"> {name}: {formatted_val}") - for name, formatted_val in formatted_args.items(): - logger.info(f"> {name}: {formatted_val}") + logger.info("---\n\n") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + logger.info(f"Log written at {timestamp}") - logger.info("---\n\n") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - logger.info(f"Log written at {timestamp}") def _get_var_names(expr: str) -> list[str]: """ @@ -234,19 +225,12 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip" - - -def _get_ngc_private_base_url(repo: str) -> str: - return f"https://api.ngc.nvidia.com/v2/{repo}/models" - - -def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str: - return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip" + return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip" def _get_monaihosting_bundle_url(model_name: str, version: str) -> str: - return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" + monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" + return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None: @@ -279,15 +263,10 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: def _download_from_ngc( - download_path: Path, - filename: str, - version: str, - prefix: str = "monai_", - remove_prefix: str | None = "monai_", - progress: bool = True, + download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool ) -> None: # ensure prefix is contained - filename = _add_ngc_prefix(filename, prefix=prefix) + filename = _add_ngc_prefix(filename) url = _get_ngc_bundle_url(model_name=filename, version=version) filepath = download_path / f"{filename}_v{version}.zip" if remove_prefix: @@ -297,175 +276,29 @@ def _download_from_ngc( extractall(filepath=filepath, output_dir=extract_path, has_base=True) -def _download_from_ngc_private( - download_path: Path, - filename: str, - version: str, - repo: str, - prefix: str = "monai_", - remove_prefix: str | None = "monai_", - headers: dict | None = None, -) -> None: - # ensure prefix is contained - filename = _add_ngc_prefix(filename, prefix=prefix) - request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo) - if has_requests: - headers = {} if headers is None else headers - response = requests_get(request_url, headers=headers) - response.raise_for_status() - else: - raise ValueError("NGC API requires requests package. Please install it.") - - os.makedirs(download_path, exist_ok=True) - zip_path = download_path / f"{filename}_v{version}.zip" - with open(zip_path, "wb") as f: - f.write(response.content) - logger.info(f"Downloading: {zip_path}.") - if remove_prefix: - filename = _remove_ngc_prefix(filename, prefix=remove_prefix) - extract_path = download_path / f"{filename}" - with zipfile.ZipFile(zip_path, "r") as z: - z.extractall(extract_path) - logger.info(f"Writing into directory: {extract_path}.") - - -def _get_ngc_token(api_key, retry=0): - """Try to connect to NGC.""" - url = "https://authn.nvidia.com/token?service=ngc" - headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key} - if has_requests: - response = requests_get(url, headers=headers) - if not response.ok: - # retry 3 times, if failed, raise an error. - if retry < 3: - logger.info(f"Retrying {retry} time(s) to GET {url}.") - return _get_ngc_token(url, retry + 1) - raise RuntimeError("NGC API response is not ok. Failed to get token.") - else: - token = response.json()["token"] - return token - - def _get_latest_bundle_version_monaihosting(name): - full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}" + url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" + full_url = f"{url}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) resp.raise_for_status() else: - raise ValueError("NGC API requires requests package. Please install it.") + raise ValueError("NGC API requires requests package. Please install it.") model_info = json.loads(resp.text) return model_info["model"]["latestVersionIdStr"] -def _examine_monai_version(monai_version: str) -> tuple[bool, str]: - """Examine if the package version is compatible with the MONAI version in the metadata.""" - version_dict = get_versions() - package_version = version_dict.get("version", "0+unknown") - if package_version == "0+unknown": - return False, "Package version is not available. Skipping version check." - if monai_version == "0+unknown": - return False, "MONAI version is not specified in the bundle. Skipping version check." - # treat rc versions as the same as the release version - package_version = re.sub(r"rc\d.*", "", package_version) - monai_version = re.sub(r"rc\d.*", "", monai_version) - if package_version < monai_version: - return ( - False, - f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.", - ) - return True, "" - - -def _check_monai_version(bundle_dir: PathLike, name: str) -> None: - """Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version""" - metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json" - if not metadata_file.exists(): - logger.warning(f"metadata file not found in {metadata_file}.") - return - with open(metadata_file) as f: - metadata = json.load(f) - is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown")) - if not is_compatible: - logger.warning(msg) - - -def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]: - """ - Extract the latest versions from the data dictionary. - - Args: - data: the data dictionary. - max_versions: the maximum number of versions to return. - - Returns: - versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0']. - """ - # Check if the data is a dictionary and it has the key 'modelVersions' - if not isinstance(data, dict) or "modelVersions" not in data: - raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.") - - # Extract the list of model versions - model_versions = data["modelVersions"] - - if ( - not isinstance(model_versions, list) - or len(model_versions) == 0 - or "createdDate" not in model_versions[0] - or "versionId" not in model_versions[0] - ): - raise ValueError( - "The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'." - ) - - # Sort the versions by the 'createdDate' in descending order - sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True) - return [v["versionId"] for v in sorted_versions[:max_versions]] - - -def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str: - base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL - version_endpoint = base_url + f"/{name.lower()}/versions/" - - if not has_requests: - raise ValueError("requests package is required, please install it.") - - version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements - if headers: - version_header.update(headers) - resp = requests_get(version_endpoint, headers=version_header) - resp.raise_for_status() - model_info = json.loads(resp.text) - latest_versions = _list_latest_versions(model_info) - - for version in latest_versions: - file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" - resp = requests_get(file_endpoint, headers=headers) - metadata = json.loads(resp.text) - resp.raise_for_status() - # if the package version is not available or the model is compatible with the package version - is_compatible, _ = _examine_monai_version(metadata["monai_version"]) - if is_compatible: - if version != latest_versions[0]: - logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.") - return version - - # if no compatible version is found, return the latest version - return latest_versions[0] - - -def _get_latest_bundle_version( - source: str, name: str, repo: str, **kwargs: Any -) -> dict[str, list[str] | str] | Any | None: +def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None: if source == "ngc": name = _add_ngc_prefix(name) - return _get_latest_bundle_version_ngc(name) + model_dict = _get_all_ngc_models(name) + for v in model_dict.values(): + if v["name"] == name: + return v["latest"] + return None elif source == "monaihosting": return _get_latest_bundle_version_monaihosting(name) - elif source == "ngc_private": - headers = kwargs.pop("headers", {}) - name = _add_ngc_prefix(name) - return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers) elif source == "github": repo_owner, repo_name, tag_name = repo.split("/") return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"] @@ -532,9 +365,6 @@ def download( # Execute this module as a CLI entry, and download bundle via URL: python -m monai.bundle download --name --url - # Execute this module as a CLI entry, and download bundle from ngc_private with latest version: - python -m monai.bundle download --name --source "ngc_private" --bundle_dir "./" --repo "org/org_name" - # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime. # The content of the JSON / YAML file is a dictionary. For example: @@ -555,17 +385,14 @@ def download( Default is `bundle` subfolder under `torch.hub.get_dir()`. source: storage location name. This argument is used when `url` is `None`. In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and - it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub". - If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable. + it should be "ngc", "monaihosting", "github", or "huggingface_hub". repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub". If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag". If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". - If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name", - or you can specify the environment variable NGC_ORG and NGC_TEAM. url: url to download the data. If not `None`, data will be downloaded directly and `source` will not be checked. If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. - remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all ngc bundles + remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to maintain the consistency between these two sources, remove prefix is necessary. Therefore, if specified, downloaded folder name will remove the prefix. @@ -593,18 +420,11 @@ def download( bundle_dir_ = _process_bundle_dir(bundle_dir_) if repo_ is None: - org_ = os.getenv("NGC_ORG", None) - team_ = os.getenv("NGC_TEAM", None) - if org_ is not None and source_ == "ngc_private": - repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}" - else: - repo_ = "Project-MONAI/model-zoo/hosting_storage_v1" - if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private": - raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.") - if len(repo_.split("/")) != 3 and source_ == "github": - raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.") + repo_ = "Project-MONAI/model-zoo/hosting_storage_v1" + if len(repo_.split("/")) != 3 and source_ != "huggingface_hub": + raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.") elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub": - raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.") + raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`") if url_ is not None: if name_ is not None: filepath = bundle_dir_ / f"{name_}.zip" @@ -613,22 +433,14 @@ def download( download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) else: - headers = {} if name_ is None: raise ValueError(f"To download from source: {source_}, `name` must be provided.") - if source == "ngc_private": - api_key = os.getenv("NGC_API_KEY", None) - if api_key is None: - raise ValueError("API key is required for ngc_private source.") - else: - token = _get_ngc_token(api_key) - headers = {"Authorization": f"Bearer {token}"} - if version_ is None: - version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers) + version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_) if source_ == "github": - name_ver = "_v".join([name_, version_]) if version_ is not None else name_ - _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) + if version_ is not None: + name_ = "_v".join([name_, version_]) + _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) elif source_ == "monaihosting": _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) elif source_ == "ngc": @@ -639,15 +451,6 @@ def download( remove_prefix=remove_prefix_, progress=progress_, ) - elif source_ == "ngc_private": - _download_from_ngc_private( - download_path=bundle_dir_, - filename=name_, - version=version_, - remove_prefix=remove_prefix_, - repo=repo_, - headers=headers, - ) elif source_ == "huggingface_hub": extract_path = os.path.join(bundle_dir_, name_) huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path) @@ -657,8 +460,6 @@ def download( f"got source: {source_}." ) - _check_monai_version(bundle_dir_, name_) - @deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") @deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") @@ -1317,7 +1118,6 @@ def verify_net_in_out( def _export( converter: Callable, - saver: Callable, parser: ConfigParser, net_id: str, filepath: str, @@ -1332,8 +1132,6 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. - saver: a callable object that accepts the converted model to save, a filepath to save to, meta values - (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. @@ -1373,9 +1171,14 @@ def _export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - meta_values = parser.get().pop("_meta_", None) - saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files) - + save_net_with_metadata( + jit_obj=net, + filename_prefix_or_stream=filepath, + include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None), + more_extra_files=extra_files, + ) logger.info(f"exported to file: {filepath}.") @@ -1474,23 +1277,17 @@ def onnx_export( input_shape_ = _get_fake_input_shape(parser=parser) inputs_ = [torch.rand(input_shape_)] + net = parser.get_parsed_content(net_id_) + if has_ignite: + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) + else: + ckpt = torch.load(ckpt_file_) + copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - - def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None: - onnx.save(onnx_obj, filename_prefix_or_stream) - - _export( - convert_to_onnx, - save_onnx, - parser, - net_id=net_id_, - filepath=filepath_, - ckpt_file=ckpt_file_, - config_file=config_file_, - key_in_ckpt=key_in_ckpt_, - **converter_kwargs_, - ) + onnx_model = convert_to_onnx(model=net, **converter_kwargs_) + onnx.save(onnx_model, filepath_) def ckpt_export( @@ -1611,12 +1408,8 @@ def ckpt_export( converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) # Use the given converter to convert a model and save with metadata, config content - - save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) - _export( convert_to_torchscript, - save_ts, parser, net_id=net_id_, filepath=filepath_, @@ -1786,11 +1579,8 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) - save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) - _export( convert_to_trt, - save_ts, parser, net_id=net_id_, filepath=filepath_, From d914f4e4af658add116113b9dd139f1646ec2b94 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:40:54 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d5ce527989..423a36347a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -54,7 +54,6 @@ get_equivalent_dtype, min_version, optional_import, - pprint_edges, ) validate, _ = optional_import("jsonschema", name="validate") From 43ca23da5e86b8ea4996d387a55e6f71987fe407 Mon Sep 17 00:00:00 2001 From: "Ted.Lai" Date: Thu, 19 Sep 2024 16:48:40 +0800 Subject: [PATCH 09/10] Fix log_input_summary function Signed-off-by: Ted.Lai --- monai/bundle/scripts.py | 300 +++++++++++++++++++++++++++++++++------- 1 file changed, 250 insertions(+), 50 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d5ce527989..3e5f6970c0 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -19,8 +19,10 @@ import re import threading import warnings +import zipfile from collections.abc import Mapping, Sequence from datetime import datetime +from functools import partial from pathlib import Path from pydoc import locate from shutil import copyfile @@ -31,11 +33,11 @@ import torch.distributed as dist from torch.cuda import is_available -from monai.apps.mmars.mmars import _get_all_ngc_models +from monai._version import get_versions from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser -from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA +from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata @@ -71,6 +73,9 @@ DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting") PPRINT_CONFIG_N = 5 +MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" +NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit" + def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ @@ -105,7 +110,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw if isinstance(v, dict) and isinstance(args_.get(k), dict): args_[k] = update_kwargs(args_[k], ignore_none, **v) else: - args_[k] = v + merge_kv(args_, k, v) return args_ @@ -121,12 +126,6 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> None: - """ - Logs the input summary of a MONAI bundle script to the console and a local log file. - Each rank's logs are tagged with their rank number and saved to individual log files. - Reads the base log path from the logging.conf file and creates a separate log file for each rank. - Add log_all_ranks as a parameter to determine whether to log all ranks or only rank 0. - """ is_distributed = dist.is_available() and dist.is_initialized() rank = dist.get_rank() if is_distributed else 0 @@ -225,12 +224,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip" + return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip" + + +def _get_ngc_private_base_url(repo: str) -> str: + return f"https://api.ngc.nvidia.com/v2/{repo}/models" + + +def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str: + return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip" def _get_monaihosting_bundle_url(model_name: str, version: str) -> str: - monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" + return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None: @@ -263,10 +269,15 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: def _download_from_ngc( - download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool + download_path: Path, + filename: str, + version: str, + prefix: str = "monai_", + remove_prefix: str | None = "monai_", + progress: bool = True, ) -> None: # ensure prefix is contained - filename = _add_ngc_prefix(filename) + filename = _add_ngc_prefix(filename, prefix=prefix) url = _get_ngc_bundle_url(model_name=filename, version=version) filepath = download_path / f"{filename}_v{version}.zip" if remove_prefix: @@ -276,29 +287,175 @@ def _download_from_ngc( extractall(filepath=filepath, output_dir=extract_path, has_base=True) +def _download_from_ngc_private( + download_path: Path, + filename: str, + version: str, + repo: str, + prefix: str = "monai_", + remove_prefix: str | None = "monai_", + headers: dict | None = None, +) -> None: + # ensure prefix is contained + filename = _add_ngc_prefix(filename, prefix=prefix) + request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo) + if has_requests: + headers = {} if headers is None else headers + response = requests_get(request_url, headers=headers) + response.raise_for_status() + else: + raise ValueError("NGC API requires requests package. Please install it.") + + os.makedirs(download_path, exist_ok=True) + zip_path = download_path / f"{filename}_v{version}.zip" + with open(zip_path, "wb") as f: + f.write(response.content) + logger.info(f"Downloading: {zip_path}.") + if remove_prefix: + filename = _remove_ngc_prefix(filename, prefix=remove_prefix) + extract_path = download_path / f"{filename}" + with zipfile.ZipFile(zip_path, "r") as z: + z.extractall(extract_path) + logger.info(f"Writing into directory: {extract_path}.") + + +def _get_ngc_token(api_key, retry=0): + """Try to connect to NGC.""" + url = "https://authn.nvidia.com/token?service=ngc" + headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key} + if has_requests: + response = requests_get(url, headers=headers) + if not response.ok: + # retry 3 times, if failed, raise an error. + if retry < 3: + logger.info(f"Retrying {retry} time(s) to GET {url}.") + return _get_ngc_token(url, retry + 1) + raise RuntimeError("NGC API response is not ok. Failed to get token.") + else: + token = response.json()["token"] + return token + + def _get_latest_bundle_version_monaihosting(name): - url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - full_url = f"{url}/{name.lower()}" + full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) resp.raise_for_status() else: - raise ValueError("NGC API requires requests package. Please install it.") + raise ValueError("NGC API requires requests package. Please install it.") model_info = json.loads(resp.text) return model_info["model"]["latestVersionIdStr"] -def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None: +def _examine_monai_version(monai_version: str) -> tuple[bool, str]: + """Examine if the package version is compatible with the MONAI version in the metadata.""" + version_dict = get_versions() + package_version = version_dict.get("version", "0+unknown") + if package_version == "0+unknown": + return False, "Package version is not available. Skipping version check." + if monai_version == "0+unknown": + return False, "MONAI version is not specified in the bundle. Skipping version check." + # treat rc versions as the same as the release version + package_version = re.sub(r"rc\d.*", "", package_version) + monai_version = re.sub(r"rc\d.*", "", monai_version) + if package_version < monai_version: + return ( + False, + f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.", + ) + return True, "" + + +def _check_monai_version(bundle_dir: PathLike, name: str) -> None: + """Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version""" + metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json" + if not metadata_file.exists(): + logger.warning(f"metadata file not found in {metadata_file}.") + return + with open(metadata_file) as f: + metadata = json.load(f) + is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown")) + if not is_compatible: + logger.warning(msg) + + +def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]: + """ + Extract the latest versions from the data dictionary. + + Args: + data: the data dictionary. + max_versions: the maximum number of versions to return. + + Returns: + versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0']. + """ + # Check if the data is a dictionary and it has the key 'modelVersions' + if not isinstance(data, dict) or "modelVersions" not in data: + raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.") + + # Extract the list of model versions + model_versions = data["modelVersions"] + + if ( + not isinstance(model_versions, list) + or len(model_versions) == 0 + or "createdDate" not in model_versions[0] + or "versionId" not in model_versions[0] + ): + raise ValueError( + "The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'." + ) + + # Sort the versions by the 'createdDate' in descending order + sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True) + return [v["versionId"] for v in sorted_versions[:max_versions]] + + +def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str: + base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL + version_endpoint = base_url + f"/{name.lower()}/versions/" + + if not has_requests: + raise ValueError("requests package is required, please install it.") + + version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements + if headers: + version_header.update(headers) + resp = requests_get(version_endpoint, headers=version_header) + resp.raise_for_status() + model_info = json.loads(resp.text) + latest_versions = _list_latest_versions(model_info) + + for version in latest_versions: + file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" + resp = requests_get(file_endpoint, headers=headers) + metadata = json.loads(resp.text) + resp.raise_for_status() + # if the package version is not available or the model is compatible with the package version + is_compatible, _ = _examine_monai_version(metadata["monai_version"]) + if is_compatible: + if version != latest_versions[0]: + logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.") + return version + + # if no compatible version is found, return the latest version + return latest_versions[0] + + +def _get_latest_bundle_version( + source: str, name: str, repo: str, **kwargs: Any +) -> dict[str, list[str] | str] | Any | None: if source == "ngc": name = _add_ngc_prefix(name) - model_dict = _get_all_ngc_models(name) - for v in model_dict.values(): - if v["name"] == name: - return v["latest"] - return None + return _get_latest_bundle_version_ngc(name) elif source == "monaihosting": return _get_latest_bundle_version_monaihosting(name) + elif source == "ngc_private": + headers = kwargs.pop("headers", {}) + name = _add_ngc_prefix(name) + return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers) elif source == "github": repo_owner, repo_name, tag_name = repo.split("/") return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"] @@ -365,6 +522,9 @@ def download( # Execute this module as a CLI entry, and download bundle via URL: python -m monai.bundle download --name --url + # Execute this module as a CLI entry, and download bundle from ngc_private with latest version: + python -m monai.bundle download --name --source "ngc_private" --bundle_dir "./" --repo "org/org_name" + # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime. # The content of the JSON / YAML file is a dictionary. For example: @@ -385,14 +545,17 @@ def download( Default is `bundle` subfolder under `torch.hub.get_dir()`. source: storage location name. This argument is used when `url` is `None`. In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and - it should be "ngc", "monaihosting", "github", or "huggingface_hub". + it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub". + If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable. repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub". If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag". If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". + If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name", + or you can specify the environment variable NGC_ORG and NGC_TEAM. url: url to download the data. If not `None`, data will be downloaded directly and `source` will not be checked. If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. - remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles + remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all ngc bundles have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to maintain the consistency between these two sources, remove prefix is necessary. Therefore, if specified, downloaded folder name will remove the prefix. @@ -420,11 +583,18 @@ def download( bundle_dir_ = _process_bundle_dir(bundle_dir_) if repo_ is None: - repo_ = "Project-MONAI/model-zoo/hosting_storage_v1" - if len(repo_.split("/")) != 3 and source_ != "huggingface_hub": - raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.") + org_ = os.getenv("NGC_ORG", None) + team_ = os.getenv("NGC_TEAM", None) + if org_ is not None and source_ == "ngc_private": + repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}" + else: + repo_ = "Project-MONAI/model-zoo/hosting_storage_v1" + if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private": + raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.") + if len(repo_.split("/")) != 3 and source_ == "github": + raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.") elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub": - raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`") + raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.") if url_ is not None: if name_ is not None: filepath = bundle_dir_ / f"{name_}.zip" @@ -433,14 +603,22 @@ def download( download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) else: + headers = {} if name_ is None: raise ValueError(f"To download from source: {source_}, `name` must be provided.") + if source == "ngc_private": + api_key = os.getenv("NGC_API_KEY", None) + if api_key is None: + raise ValueError("API key is required for ngc_private source.") + else: + token = _get_ngc_token(api_key) + headers = {"Authorization": f"Bearer {token}"} + if version_ is None: - version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_) + version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers) if source_ == "github": - if version_ is not None: - name_ = "_v".join([name_, version_]) - _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) + name_ver = "_v".join([name_, version_]) if version_ is not None else name_ + _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) elif source_ == "ngc": @@ -451,6 +629,15 @@ def download( remove_prefix=remove_prefix_, progress=progress_, ) + elif source_ == "ngc_private": + _download_from_ngc_private( + download_path=bundle_dir_, + filename=name_, + version=version_, + remove_prefix=remove_prefix_, + repo=repo_, + headers=headers, + ) elif source_ == "huggingface_hub": extract_path = os.path.join(bundle_dir_, name_) huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path) @@ -460,6 +647,8 @@ def download( f"got source: {source_}." ) + _check_monai_version(bundle_dir_, name_) + @deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") @deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") @@ -1118,6 +1307,7 @@ def verify_net_in_out( def _export( converter: Callable, + saver: Callable, parser: ConfigParser, net_id: str, filepath: str, @@ -1132,6 +1322,8 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. + saver: a callable object that accepts the converted model to save, a filepath to save to, meta values + (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. @@ -1171,14 +1363,9 @@ def _export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream=filepath, - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, - ) + meta_values = parser.get().pop("_meta_", None) + saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files) + logger.info(f"exported to file: {filepath}.") @@ -1277,17 +1464,23 @@ def onnx_export( input_shape_ = _get_fake_input_shape(parser=parser) inputs_ = [torch.rand(input_shape_)] - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - onnx_model = convert_to_onnx(model=net, **converter_kwargs_) - onnx.save(onnx_model, filepath_) + + def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None: + onnx.save(onnx_obj, filename_prefix_or_stream) + + _export( + convert_to_onnx, + save_onnx, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, + ) def ckpt_export( @@ -1408,8 +1601,12 @@ def ckpt_export( converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) # Use the given converter to convert a model and save with metadata, config content + + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( convert_to_torchscript, + save_ts, parser, net_id=net_id_, filepath=filepath_, @@ -1579,8 +1776,11 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( convert_to_trt, + save_ts, parser, net_id=net_id_, filepath=filepath_, From 69001d76f8d8e0270f9b80e2b4f80977e787beef Mon Sep 17 00:00:00 2001 From: "Ted.Lai" Date: Mon, 23 Sep 2024 16:45:13 +0800 Subject: [PATCH 10/10] Fix log_input_summary Signed-off-by: Ted.Lai --- monai/bundle/scripts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 7bbec7692c..054207d0e5 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -23,6 +23,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from functools import partial +from logging import LoggerAdapter from pathlib import Path from pydoc import locate from shutil import copyfile @@ -160,7 +161,7 @@ def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> Non console_handler.setFormatter(formatter) logger.addHandler(console_handler) - logger = logging.LoggerAdapter(logger, {"rank": rank}) + logger: LoggerAdapter = logging.LoggerAdapter(logger, {"rank": rank}) formatted_args = { name: ", ".join(map(str, val)) if isinstance(val, (list, tuple, dict)) else val