Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6387 update_kwargs for merging multiple configs #7109

Merged
merged 9 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/bundle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ Model Bundle
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
.. autofunction:: init_bundle
.. autofunction:: update_kwargs
1 change: 1 addition & 0 deletions monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
run,
run_workflow,
trt_export,
update_kwargs,
verify_metadata,
verify_net_in_out,
)
Expand Down
5 changes: 4 additions & 1 deletion monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,16 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs

Args:
files: path of target files to load, supported postfixes: `.json`, `.yml`, `.yaml`.
if providing a list of files, wil merge the content of them.
if providing a list of files, will merge the content of them.
if providing a string with comma separated file paths, will merge the content of them.
if providing a dictionary, return it directly.
kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format.
"""
if isinstance(files, dict): # already a config dict
return files
parser = ConfigParser(config={})
if isinstance(files, str) and not Path(files).is_file() and "," in files:
files = files.split(",")
for i in ensure_tuple(files):
for k, v in (cls.load_config_file(i, **kwargs)).items():
parser[k] = v
Expand Down
39 changes: 26 additions & 13 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,46 @@
PPRINT_CONFIG_N = 5


def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
wyli marked this conversation as resolved.
Show resolved Hide resolved
"""
Update the `args` with the input `kwargs`.
Update the `args` dictionary with the input `kwargs`.
For dict data, recursively update the content based on the keys.

Example::

from monai.bundle import update_kwargs
update_kwargs({'exist': 1}, exist=2, new_arg=3)
# return {'exist': 2, 'new_arg': 3}

Args:
args: source args to update.
args: source `args` dictionary (or a json/yaml filename to read as dictionary) to update.
ignore_none: whether to ignore input args with None value, default to `True`.
kwargs: destination args to update.
kwargs: key=value pairs to be merged into `args`.

"""
args_: dict = args if isinstance(args, dict) else {}
if isinstance(args, str):
# args are defined in a structured file
args_ = ConfigParser.load_config_file(args)
if isinstance(args, (tuple, list)) and all(isinstance(x, str) for x in args):
primary, overrides = args
args_ = update_kwargs(primary, ignore_none, **update_kwargs(overrides, ignore_none, **kwargs))
if not isinstance(args_, dict):
return args_
# recursively update the default args with new args
for k, v in kwargs.items():
print(k, v)
if ignore_none and v is None:
continue
if isinstance(v, dict) and isinstance(args_.get(k), dict):
args_[k] = _update_args(args_[k], ignore_none, **v)
args_[k] = update_kwargs(args_[k], ignore_none, **v)
else:
args_[k] = v
return args_


_update_args = update_kwargs # backward compatibility


def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple:
"""
Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`.
Expand Down Expand Up @@ -318,7 +331,7 @@ def download(
so that the command line inputs can be simplified.

"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
name=name,
version=version,
Expand Down Expand Up @@ -834,7 +847,7 @@ def verify_metadata(

"""

_args = _update_args(
_args = update_kwargs(
args=args_file,
meta_file=meta_file,
filepath=filepath,
Expand Down Expand Up @@ -958,7 +971,7 @@ def verify_net_in_out(

"""

_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
meta_file=meta_file,
Expand Down Expand Up @@ -1127,7 +1140,7 @@ def onnx_export(
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
filepath=filepath,
Expand Down Expand Up @@ -1242,7 +1255,7 @@ def ckpt_export(
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
filepath=filepath,
Expand Down Expand Up @@ -1401,7 +1414,7 @@ def trt_export(
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
filepath=filepath,
Expand Down Expand Up @@ -1614,7 +1627,7 @@ def create_workflow(
kwargs: arguments to instantiate the workflow class.

"""
_args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
_log_input_summary(tag="run", args=_args)
(workflow_name, config_file) = _pop_args(
_args, workflow_name=ConfigWorkflow, config_file=None
Expand Down
18 changes: 10 additions & 8 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import os
import sys
import time
import warnings
from abc import ABC, abstractmethod
from copy import copy
from logging.config import fileConfig
Expand Down Expand Up @@ -158,7 +157,7 @@ def add_property(self, name: str, required: str, desc: str | None = None) -> Non
if self.properties is None:
self.properties = {}
if name in self.properties:
warnings.warn(f"property '{name}' already exists in the properties list, overriding it.")
logger.warn(f"property '{name}' already exists in the properties list, overriding it.")
self.properties[name] = {BundleProperty.DESC: desc, BundleProperty.REQUIRED: required}

def check_properties(self) -> list[str] | None:
Expand Down Expand Up @@ -241,7 +240,7 @@ def __init__(
for _config_file in _config_files:
_config_file = Path(_config_file)
if _config_file.parent != self.config_root_path:
warnings.warn(
logger.warn(
f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are"
f"not specified, {self.config_root_path} will be used as the default config root directory."
)
Expand All @@ -254,7 +253,7 @@ def __init__(
if logging_file is not None:
if not os.path.exists(logging_file):
if logging_file == str(self.config_root_path / "logging.conf"):
warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
else:
Expand All @@ -265,7 +264,10 @@ def __init__(
self.parser.read_config(f=config_file)
meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file
if isinstance(meta_file, str) and not os.path.exists(meta_file):
raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
logger.error(
f"Cannot find the metadata config file: {meta_file}. "
"Please see: https://docs.monai.io/en/stable/mb_specification.html"
)
else:
self.parser.read_meta(f=meta_file)

Expand Down Expand Up @@ -323,17 +325,17 @@ def check_properties(self) -> list[str] | None:
"""
ret = super().check_properties()
if self.properties is None:
warnings.warn("No available properties had been set, skipping check.")
logger.warn("No available properties had been set, skipping check.")
return None
if ret:
warnings.warn(f"Loaded bundle does not contain the following required properties: {ret}")
logger.warn(f"Loaded bundle does not contain the following required properties: {ret}")
# also check whether the optional properties use correct ID name if existing
wrong_props = []
for n, p in self.properties.items():
if not p.get(BundleProperty.REQUIRED, False) and not self._check_optional_id(name=n, property=p):
wrong_props.append(n)
if wrong_props:
warnings.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}")
logger.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}")
if ret is not None:
ret.extend(wrong_props)
return ret
Expand Down
2 changes: 1 addition & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def astype(self, dtype, device=None, *_args, **_kwargs):
@property
def affine(self) -> torch.Tensor:
"""Get the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``"""
return self.meta.get(MetaKeys.AFFINE, self.get_default_affine())
return self.meta.get(MetaKeys.AFFINE, self.get_default_affine()) # type: ignore

@affine.setter
def affine(self, d: NdarrayTensor) -> None:
Expand Down
5 changes: 2 additions & 3 deletions runtests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function print_usage {
echo " -c, --clean : clean temporary files from tests and exit"
echo " -h, --help : show this help message and exit"
echo " -v, --version : show MONAI and system version information and exit"
echo " -p, --path : specify the path used for formatting"
echo " -p, --path : specify the path used for formatting, default is the current dir if unspecified"
echo " --formatfix : format code using \"isort\" and \"black\" for user specified directories"
echo ""
echo "${separator}For bug reports and feature requests, please file an issue at:"
Expand Down Expand Up @@ -359,10 +359,9 @@ if [ -e "$testdir" ]
then
homedir=$testdir
else
print_error_msg "Incorrect path: $testdir provided, run under $currentdir"
homedir=$currentdir
fi
echo "run tests under $homedir"
echo "Run tests under $homedir"
cd "$homedir"

# python path
Expand Down
2 changes: 2 additions & 0 deletions tests/test_bundle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch

from monai.bundle import update_kwargs
from monai.bundle.utils import load_bundle_config
from monai.networks.nets import UNet
from monai.utils import pprint_edges
Expand Down Expand Up @@ -141,6 +142,7 @@ def test_str(self):
"[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]",
)
self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3))
self.assertEqual(update_kwargs({"a": 1}, a=2, b=3), {"a": 2, "b": 3})


if __name__ == "__main__":
Expand Down
10 changes: 4 additions & 6 deletions tests/test_integration_bundle_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,8 @@ def test_tiny(self):
with self.assertRaises(RuntimeError):
# test wrong run_id="run"
command_line_tests(cmd + ["run", "run", "--config_file", config_file])
with self.assertRaises(RuntimeError):
# test missing meta file
command_line_tests(cmd + ["run", "training", "--config_file", config_file])
# test missing meta file
self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file]))

def test_scripts_fold(self):
# test scripts directory has been added to Python search directories automatically
Expand Down Expand Up @@ -150,9 +149,8 @@ def test_scripts_fold(self):
print(output)
self.assertTrue(expected_condition in output)

with self.assertRaises(RuntimeError):
# test missing meta file
command_line_tests(cmd + ["run", "training", "--config_file", config_file])
# test missing meta file
self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file]))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, config_file, expected_shape):
Expand Down
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ def command_line_tests(cmd, copy_env=True):
try:
normal_out = subprocess.run(cmd, env=test_env, check=True, capture_output=True)
print(repr(normal_out).replace("\\n", "\n").replace("\\t", "\t"))
return repr(normal_out)
except subprocess.CalledProcessError as e:
output = repr(e.stdout).replace("\\n", "\n").replace("\\t", "\t")
errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t")
Expand Down