diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index f31ad0e814..054207d0e5 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -12,13 +12,18 @@ from __future__ import annotations import ast +import configparser import json +import logging.config import os import re +import threading import warnings import zipfile from collections.abc import Mapping, Sequence +from datetime import datetime from functools import partial +from logging import LoggerAdapter from pathlib import Path from pydoc import locate from shutil import copyfile @@ -26,6 +31,7 @@ from typing import Any, Callable import torch +import torch.distributed as dist from torch.cuda import is_available from monai._version import get_versions @@ -51,7 +57,6 @@ get_equivalent_dtype, min_version, optional_import, - pprint_edges, ) validate, _ = optional_import("jsonschema", name="validate") @@ -120,11 +125,57 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()]) -def _log_input_summary(tag: str, args: dict) -> None: - logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") - for name, val in args.items(): - logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}") - logger.info("---\n\n") +def _log_input_summary(tag: str, args: dict, log_all_ranks: bool = False) -> None: + is_distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_distributed else 0 + + if not log_all_ranks and rank != 0: + return + + log_lock = threading.Lock() + + with log_lock: + config_file = "logging.conf" + config = configparser.ConfigParser() + config.read(config_file) + + if config.has_option("handler_fileHandler", "args"): + base_log_dir = config.get("handler_fileHandler", "args").strip("()").split(",")[0].strip().strip("'") + else: + base_log_dir = "logs/" + + if not os.path.exists(base_log_dir): + os.makedirs(base_log_dir) + + log_file_path = os.path.join(base_log_dir, f"rank_{rank}_logfile.log") + logger = logging.getLogger(f"rank_{rank}_logger") + logger.setLevel(logging.INFO) + + if not logger.hasHandlers(): + file_handler = logging.FileHandler(log_file_path, mode="a") + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + logger: LoggerAdapter = logging.LoggerAdapter(logger, {"rank": rank}) + + formatted_args = { + name: ", ".join(map(str, val)) if isinstance(val, (list, tuple, dict)) else val + for name, val in args.items() + } + + logger.info(f"--- Input summary of monai.bundle.scripts.{tag} ---") + + for name, formatted_val in formatted_args.items(): + logger.info(f"> {name}: {formatted_val}") + + logger.info("---\n\n") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + logger.info(f"Log written at {timestamp}") def _get_var_names(expr: str) -> list[str]: