Skip to content

Commit

Permalink
8134 Add unit test for responsive inference (#8146)
Browse files Browse the repository at this point in the history
Fixes #8134 .

### Description

This PR added unit test to cover the realtime inference with bundles.
And updated `BundleWorkflow` to support cyclically calling the `run`
function with all components instantiated.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Nic Ma <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
4 people authored Nov 24, 2024
1 parent b1e915c commit d94df3f
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 3 deletions.
10 changes: 10 additions & 0 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str
"""
return self._resolve_one_item(id=id, **kwargs)

def remove_resolved_content(self, id: str) -> Any | None:
"""
Remove the resolved ``ConfigItem`` by id.
Args:
id: id name of the expected item.
"""
return self.resolved_content.pop(id) if id in self.resolved_content else None

@classmethod
def normalize_id(cls, id: str | int) -> str:
"""
Expand Down
19 changes: 17 additions & 2 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,23 @@ def check_properties(self) -> list[str] | None:
ret.extend(wrong_props)
return ret

def _run_expr(self, id: str, **kwargs: dict) -> Any:
return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
"""
Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
"""
ret = []
if id in self.parser:
# suppose all the expressions are in a list, run and reset the expressions
if isinstance(self.parser[id], list):
for i in range(len(self.parser[id])):
sub_id = f"{id}{ID_SEP_KEY}{i}"
ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
self.parser.ref_resolver.remove_resolved_content(sub_id)
else:
ret.append(self.parser.get_parsed_content(id, **kwargs))
self.parser.ref_resolver.remove_resolved_content(id)
return ret

def _get_prop_id(self, name: str, property: dict) -> Any:
prop_id = property[BundlePropertyConfig.ID]
Expand Down
35 changes: 34 additions & 1 deletion tests/test_bundle_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from monai.data import Dataset
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.networks.nets import UNet
from monai.transforms import Compose, LoadImage
from monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged
from tests.nonconfig_workflow import NonConfigWorkflow

TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")]
Expand All @@ -35,6 +35,8 @@

TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]

TEST_CASE_4 = [os.path.join(os.path.dirname(__file__), "testing_data", "responsive_inference.json")]

TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."]


Expand All @@ -45,7 +47,9 @@ def setUp(self):
self.expected_shape = (128, 128, 128)
test_image = np.random.rand(*self.expected_shape)
self.filename = os.path.join(self.data_dir, "image.nii")
self.filename1 = os.path.join(self.data_dir, "image1.nii")
nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename1)

def tearDown(self):
shutil.rmtree(self.data_dir)
Expand Down Expand Up @@ -115,6 +119,35 @@ def test_inference_config(self, config_file):
self._test_inferer(inferer)
self.assertEqual(inferer.workflow_type, None)

@parameterized.expand([TEST_CASE_4])
def test_responsive_inference_config(self, config_file):
input_loader = LoadImaged(keys="image")
output_saver = SaveImaged(keys="pred", output_dir=self.data_dir, output_postfix="seg")

# test standard MONAI model-zoo config workflow
inferer = ConfigWorkflow(
workflow_type="infer",
config_file=config_file,
logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
)
# FIXME: temp add the property for test, we should add it to some formal realtime infer properties
inferer.add_property(name="dataflow", required=True, config_id="dataflow")

inferer.initialize()
inferer.dataflow.update(input_loader({"image": self.filename}))
inferer.run()
output_saver(inferer.dataflow)
self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image", "image_seg.nii.gz")))

# bundle is instantiated and idle, just change the input for next inference
inferer.dataflow.clear()
inferer.dataflow.update(input_loader({"image": self.filename1}))
inferer.run()
output_saver(inferer.dataflow)
self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image1", "image1_seg.nii.gz")))

inferer.finalize()

@parameterized.expand([TEST_CASE_3])
def test_train_config(self, config_file):
# test standard MONAI model-zoo config workflow
Expand Down
101 changes: 101 additions & 0 deletions tests/testing_data/responsive_inference.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"imports": [
"$from collections import defaultdict"
],
"bundle_root": "will override",
"device": "$torch.device('cpu')",
"network_def": {
"_target_": "UNet",
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 2,
"channels": [
2,
2,
4,
8,
4
],
"strides": [
2,
2,
2,
2
],
"num_res_units": 2,
"norm": "batch"
},
"network": "$@network_def.to(@device)",
"dataflow": "$defaultdict()",
"preprocessing": {
"_target_": "Compose",
"transforms": [
{
"_target_": "EnsureChannelFirstd",
"keys": "image"
},
{
"_target_": "ScaleIntensityd",
"keys": "image"
},
{
"_target_": "RandRotated",
"_disabled_": true,
"keys": "image"
}
]
},
"dataset": {
"_target_": "Dataset",
"data": [
"@dataflow"
],
"transform": "@preprocessing"
},
"dataloader": {
"_target_": "DataLoader",
"dataset": "@dataset",
"batch_size": 1,
"shuffle": false,
"num_workers": 0
},
"inferer": {
"_target_": "SlidingWindowInferer",
"roi_size": [
64,
64,
32
],
"sw_batch_size": 4,
"overlap": 0.25
},
"postprocessing": {
"_target_": "Compose",
"transforms": [
{
"_target_": "Activationsd",
"keys": "pred",
"softmax": true
},
{
"_target_": "AsDiscreted",
"keys": "pred",
"argmax": true
}
]
},
"evaluator": {
"_target_": "SupervisedEvaluator",
"device": "@device",
"val_data_loader": "@dataloader",
"network": "@network",
"inferer": "@inferer",
"postprocessing": "@postprocessing",
"amp": false,
"epoch_length": 1
},
"run": [
"[email protected]()",
"[email protected](@evaluator.state.output[0])"
]
}

0 comments on commit d94df3f

Please sign in to comment.