Skip to content

Commit

Permalink
[DLMED] update patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
Nic-Ma committed Mar 10, 2022
1 parent ca16681 commit e9f5179
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 268 deletions.
26 changes: 12 additions & 14 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,17 @@ class ConfigParser:
config = {
"my_dims": 2,
"dims_1": "$@my_dims + 1",
"my_xform": {"<name>": "LoadImage"},
"my_net": {"<name>": "BasicUNet",
"<args>": {"spatial_dims": "@dims_1", "in_channels": 1, "out_channels": 4}},
"trainer": {"<name>": "SupervisedTrainer",
"<args>": {"network": "@my_net", "preprocessing": "@my_xform"}}
"my_xform": {"_name_": "LoadImage"},
"my_net": {"_name_": "BasicUNet", "spatial_dims": "@dims_1", "in_channels": 1, "out_channels": 4},
"trainer": {"_name_": "SupervisedTrainer", "network": "@my_net", "preprocessing": "@my_xform"}
}
# in the example $@my_dims + 1 is an expression, which adds 1 to the value of @my_dims
parser = ConfigParser(config)
# get/set configuration content, the set method should happen before calling parse()
print(parser["my_net"]["<args>"]["in_channels"]) # original input channels 1
parser["my_net"]["<args>"]["in_channels"] = 4 # change input channels to 4
print(parser["my_net"]["<args>"]["in_channels"])
print(parser["my_net"]["in_channels"]) # original input channels 1
parser["my_net"]["in_channels"] = 4 # change input channels to 4
print(parser["my_net"]["in_channels"])
# instantiate the network component
parser.parse(True)
Expand Down Expand Up @@ -107,7 +105,7 @@ def __getitem__(self, id: Union[str, int]):
id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
go one level further into the nested structures.
Use digits indexing from "0" for list or other strings for dict.
For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.
"""
if id == "":
Expand All @@ -129,7 +127,7 @@ def __setitem__(self, id: Union[str, int], config: Any):
id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
go one level further into the nested structures.
Use digits indexing from "0" for list or other strings for dict.
For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.
config: config to set at location ``id``.
"""
Expand Down Expand Up @@ -171,7 +169,7 @@ def _do_resolve(self, config: Any):
"""
Recursively resolve the config content to replace the macro tokens with target content.
The macro tokens start with "%", can be from another structured file, like:
``{"net": "%default_net"}``, ``{"net": "%/data/config.json#net#<args>"}``.
``{"net": "%default_net"}``, ``{"net": "%/data/config.json#net"}``.
Args:
config: input config file to resolve.
Expand All @@ -190,7 +188,7 @@ def resolve_macro(self):
"""
Recursively resolve `self.config` to replace the macro tokens with target content.
The macro tokens are marked as starting with "%", can be from another structured file, like:
``"%default_net"``, ``"%/data/config.json#net#<args>"``.
``"%default_net"``, ``"%/data/config.json#net"``.
"""
self.set(self._do_resolve(config=deepcopy(self.get())))
Expand All @@ -204,7 +202,7 @@ def _do_parse(self, config, id: str = ""):
id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
go one level further into the nested structures.
Use digits indexing from "0" for list or other strings for dict.
For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.
"""
if isinstance(config, (dict, list)):
Expand Down Expand Up @@ -248,7 +246,7 @@ def get_parsed_content(self, id: str = "", **kwargs):
id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
go one level further into the nested structures.
Use digits indexing from "0" for list or other strings for dict.
For example: ``"xform#5"``, ``"net#<args>#channels"``. ``""`` indicates the entire ``self.config``.
For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.
kwargs: additional keyword arguments to be passed to ``_resolve_one_item``.
Currently support ``reset`` (for parse), ``instantiate`` and ``eval_expr``. All defaulting to True.
Expand Down
4 changes: 2 additions & 2 deletions monai/bundle/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ConfigReader:
suffixes = ("json", "yaml", "yml")
suffix_match = rf"\.({'|'.join(suffixes)})"
path_match = rf"(.*{suffix_match}$)"
meta_key = "<meta>" # field key to save metadata
meta_key = "_meta_" # field key to save metadata
sep = "#" # separator for file path and the id of content in the file

def __init__(self):
Expand Down Expand Up @@ -122,7 +122,7 @@ def split_path_id(cls, src: str) -> Tuple[str, str]:
def read_meta(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs):
"""
Read the metadata from specified JSON or YAML file.
The metadata as a dictionary will be stored at ``self.config["<meta>"]``.
The metadata as a dictionary will be stored at ``self.config["_meta_"]``.
Args:
f: filepath of the metadata file, the content must be a dictionary,
Expand Down
4 changes: 2 additions & 2 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class ReferenceResolver:
_vars = "__local_refs"
sep = "#" # separator for key indexing
ref = "@" # reference prefix
# match a reference string, e.g. "@id#key", "@id#key#0", "@<test>#<args>#key"
id_matcher = re.compile(rf"{ref}(?:(?:<\w*>)|(?:\w*))(?:(?:{sep}<\w*>)|(?:{sep}\w*))*")
# match a reference string, e.g. "@id#key", "@id#key#0", "@_name_#key"
id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*")

def __init__(self, items: Optional[Sequence[ConfigItem]] = None):
# save the items in a dictionary with the `ConfigItem.id` as key
Expand Down
8 changes: 4 additions & 4 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def run(
python -m monai.bundle run --meta_file <meta path> --config_file <config path> --target_id trainer
# Override config values at runtime by specifying the component id and its new value:
python -m monai.bundle run --"net#<args>#input_chns" 1 ...
python -m monai.bundle run --net#input_chns 1 ...
# Override config values with another config file `/path/to/another.json`:
python -m monai.bundle run --"net#<args>" "%/path/to/another.json" ...
python -m monai.bundle run --net %/path/to/another.json ...
# Override config values with part content of another config file:
python -m monai.bundle run --"net#<args>" "%/data/other.json#net_arg" ...
python -m monai.bundle run --net %/data/other.json#net_arg ...
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
# Other args still can override the default args at runtime:
Expand All @@ -55,7 +55,7 @@ def run(
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
`target_id` and override pairs. so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--"net#<args>#input_chns" 42``.
e.g. ``--net#input_chns 42``.
"""
k_v = zip(["meta_file", "config_file", "target_id"], [meta_file, config_file, target_id])
Expand Down
8 changes: 4 additions & 4 deletions tests/test_bundle_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ def test_shape(self, config_file, expected_shape):
saver = LoadImage(image_only=True)

if sys.platform == "win32":
override = "--network $@network_def.to(@device) --dataset#<name> Dataset"
override = "--network $@network_def.to(@device) --dataset#_name_ Dataset"
else:
override = f"--network %{overridefile1}#move_net --dataset#<name> %{overridefile2}"
override = f"--network %{overridefile1}#move_net --dataset#_name_ %{overridefile2}"
# test with `monai.bundle` as CLI entry directly
cmd = "-m monai.bundle run --target_id evaluator"
cmd += f" --postprocessing#<args>#transforms#2#<args>#output_postfix seg {override}"
cmd += f" --postprocessing#transforms#2#output_postfix seg {override}"
la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
ret = subprocess.check_call(la + ["--args_file", def_args_file])
self.assertEqual(ret, 0)
self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)

# here test the script with `google fire` tool as CLI
cmd = "-m fire monai.bundle.scripts run --target_id evaluator"
cmd += f" --evaluator#<args>#amp False {override}"
cmd += f" --evaluator#amp False {override}"
la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
ret = subprocess.check_call(la)
self.assertEqual(ret, 0)
Expand Down
5 changes: 1 addition & 4 deletions tests/test_config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@
# test `_disabled_` with string
TEST_CASE_5 = [{"_name_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict]
# test non-monai modules and excludes
TEST_CASE_6 = [
{"_path_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4},
torch.optim.Adam,
]
TEST_CASE_6 = [{"_path_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam]
TEST_CASE_7 = [{"_name_": "decollate_batch", "detach": True, "pad": True}, partial]
# test args contains "name" field
TEST_CASE_8 = [
Expand Down
49 changes: 23 additions & 26 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,21 @@
TEST_CASE_1 = [
{
"transform": {
"<name>": "Compose",
"<args>": {
"transforms": [
{"<name>": "LoadImaged", "<args>": {"keys": "image"}},
{
"<name>": "RandTorchVisiond",
"<args>": {"keys": "image", "name": "ColorJitter", "brightness": 0.25},
},
]
},
"_name_": "Compose",
"transforms": [
{"_name_": "LoadImaged", "keys": "image"},
{"_name_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
],
},
"dataset": {"<name>": "Dataset", "<args>": {"data": [1, 2], "transform": "@transform"}},
"dataset": {"_name_": "Dataset", "data": [1, 2], "transform": "@transform"},
"dataloader": {
"<name>": "DataLoader",
"<args>": {"dataset": "@dataset", "batch_size": 2, "collate_fn": "monai.data.list_data_collate"},
"_name_": "DataLoader",
"dataset": "@dataset",
"batch_size": 2,
"collate_fn": "monai.data.list_data_collate",
},
},
["transform", "transform#<args>#transforms#0", "transform#<args>#transforms#1", "dataset", "dataloader"],
["transform", "transform#transforms#0", "transform#transforms#1", "dataset", "dataloader"],
[Compose, LoadImaged, RandTorchVisiond, Dataset, DataLoader],
]

Expand All @@ -67,9 +64,9 @@ def __call__(self, a, b):
"cls_func": "$TestClass.cls_compute",
"lambda_static_func": "$lambda x, y: TestClass.compute(x, y)",
"lambda_cls_func": "$lambda x, y: TestClass.cls_compute(x, y)",
"compute": {"<path>": "tests.test_config_parser.TestClass.compute", "<args>": {"func": "@basic_func"}},
"cls_compute": {"<path>": "tests.test_config_parser.TestClass.cls_compute", "<args>": {"func": "@basic_func"}},
"call_compute": {"<path>": "tests.test_config_parser.TestClass"},
"compute": {"_path_": "tests.test_config_parser.TestClass.compute", "func": "@basic_func"},
"cls_compute": {"_path_": "tests.test_config_parser.TestClass.cls_compute", "func": "@basic_func"},
"call_compute": {"_path_": "tests.test_config_parser.TestClass"},
"error_func": "$TestClass.__call__",
"<test>": "$lambda x, y: x + y",
}
Expand All @@ -78,17 +75,17 @@ def __call__(self, a, b):

class TestConfigComponent(unittest.TestCase):
def test_config_content(self):
test_config = {"preprocessing": [{"<name>": "LoadImage"}], "dataset": {"<name>": "Dataset"}}
test_config = {"preprocessing": [{"_name_": "LoadImage"}], "dataset": {"_name_": "Dataset"}}
parser = ConfigParser(config=test_config)
# test `get`, `set`, `__getitem__`, `__setitem__`
self.assertEqual(str(parser.get()), str(test_config))
parser.set(config=test_config)
self.assertListEqual(parser["preprocessing"], test_config["preprocessing"])
parser["dataset"] = {"<name>": "CacheDataset"}
self.assertEqual(parser["dataset"]["<name>"], "CacheDataset")
parser["dataset"] = {"_name_": "CacheDataset"}
self.assertEqual(parser["dataset"]["_name_"], "CacheDataset")
# test nested ids
parser["dataset#<name>"] = "Dataset"
self.assertEqual(parser["dataset#<name>"], "Dataset")
parser["dataset#_name_"] = "Dataset"
self.assertEqual(parser["dataset#_name_"], "Dataset")
# test int id
parser.set(["test1", "test2", "test3"])
parser[1] = "test4"
Expand All @@ -99,11 +96,11 @@ def test_config_content(self):
def test_parse(self, config, expected_ids, output_types):
parser = ConfigParser(config=config, globals={"monai": "monai"})
# test lazy instantiation with original config content
parser["transform"]["<args>"]["transforms"][0]["<args>"]["keys"] = "label1"
self.assertEqual(parser.get_parsed_content(id="transform#<args>#transforms#0").keys[0], "label1")
parser["transform"]["transforms"][0]["keys"] = "label1"
self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1")
# test nested id
parser["transform#<args>#transforms#0#<args>#keys"] = "label2"
self.assertEqual(parser.get_parsed_content(id="transform#<args>#transforms#0").keys[0], "label2")
parser["transform#transforms#0#keys"] = "label2"
self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label2")
for id, cls in zip(expected_ids, output_types):
self.assertTrue(isinstance(parser.get_parsed_content(id), cls))
# test root content
Expand Down
48 changes: 19 additions & 29 deletions tests/test_reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@
TEST_CASE_1 = [
{
# all the recursively parsed config items
"transform#1": {"<name>": "LoadImaged", "<args>": {"keys": ["image"]}},
"transform#1#<name>": "LoadImaged",
"transform#1#<args>": {"keys": ["image"]},
"transform#1#<args>#keys": ["image"],
"transform#1#<args>#keys#0": "image",
"transform#1": {"_name_": "LoadImaged", "keys": ["image"]},
"transform#1#_name_": "LoadImaged",
"transform#1#keys": ["image"],
"transform#1#keys#0": "image",
},
"transform#1",
LoadImaged,
Expand All @@ -40,20 +39,15 @@
TEST_CASE_2 = [
{
# some the recursively parsed config items
"dataloader": {
"<name>": "DataLoader",
"<args>": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"},
},
"dataset": {"<name>": "Dataset", "<args>": {"data": [1, 2]}},
"dataloader#<name>": "DataLoader",
"dataloader#<args>": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"},
"dataloader#<args>#dataset": "@dataset",
"dataloader#<args>#collate_fn": "$monai.data.list_data_collate",
"dataset#<name>": "Dataset",
"dataset#<args>": {"data": [1, 2]},
"dataset#<args>#data": [1, 2],
"dataset#<args>#data#0": 1,
"dataset#<args>#data#1": 2,
"dataloader": {"_name_": "DataLoader", "dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"},
"dataset": {"_name_": "Dataset", "data": [1, 2]},
"dataloader#_name_": "DataLoader",
"dataloader#dataset": "@dataset",
"dataloader#collate_fn": "$monai.data.list_data_collate",
"dataset#_name_": "Dataset",
"dataset#data": [1, 2],
"dataset#data#0": 1,
"dataset#data#1": 2,
},
"dataloader",
DataLoader,
Expand All @@ -62,15 +56,11 @@
TEST_CASE_3 = [
{
# all the recursively parsed config items
"transform#1": {
"<name>": "RandTorchVisiond",
"<args>": {"keys": "image", "name": "ColorJitter", "brightness": 0.25},
},
"transform#1#<name>": "RandTorchVisiond",
"transform#1#<args>": {"keys": "image", "name": "ColorJitter", "brightness": 0.25},
"transform#1#<args>#keys": "image",
"transform#1#<args>#name": "ColorJitter",
"transform#1#<args>#brightness": 0.25,
"transform#1": {"_name_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
"transform#1#_name_": "RandTorchVisiond",
"transform#1#keys": "image",
"transform#1#name": "ColorJitter",
"transform#1#brightness": 0.25,
},
"transform#1",
RandTorchVisiond,
Expand All @@ -97,7 +87,7 @@ def test_resolve(self, configs, expected_id, output_type):
# test lazy instantiation
item = resolver.get_item(expected_id, resolve=True)
config = item.get_config()
config["<disabled>"] = False
config["_disabled_"] = False
item.update_config(config=config)
if isinstance(item, ConfigComponent):
result = item.instantiate()
Expand Down
Loading

0 comments on commit e9f5179

Please sign in to comment.