From 2016f839c2cd1e492fc38b89f6b1f9a7224e4385 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 29 Jul 2023 18:01:02 +0100 Subject: [PATCH] update Signed-off-by: Wenqi Li --- monai/apps/pathology/inferers/inferer.py | 1 + monai/inferers/inferer.py | 5 +++++ monai/inferers/utils.py | 8 +++++++- tests/test_sliding_window_hovernet_inference.py | 1 + tests/test_sliding_window_inference.py | 1 + 5 files changed, 15 insertions(+), 1 deletion(-) diff --git a/monai/apps/pathology/inferers/inferer.py b/monai/apps/pathology/inferers/inferer.py index 7a60c23aa20..71259ca7dfd 100644 --- a/monai/apps/pathology/inferers/inferer.py +++ b/monai/apps/pathology/inferers/inferer.py @@ -178,6 +178,7 @@ def __call__( self.process_output, self.buffer_steps, self.buffer_dim, + False, *args, **kwargs, ) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 5484970d824..bf8c27e5c36 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -426,6 +426,8 @@ class SlidingWindowInferer(Inferer): (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency. buffer_dim: the spatial dimension along which the buffers are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. + with_coord: whether to pass the window coordinates to ``network``. Defaults to False. + If True, the ``network``'s 2nd input argument should accept the window coordinates. Note: ``sw_batch_size`` denotes the max number of windows per network inference iteration, @@ -449,6 +451,7 @@ def __init__( cpu_thresh: int | None = None, buffer_steps: int | None = None, buffer_dim: int = -1, + with_coord: bool = False, ) -> None: super().__init__() self.roi_size = roi_size @@ -464,6 +467,7 @@ def __init__( self.cpu_thresh = cpu_thresh self.buffer_steps = buffer_steps self.buffer_dim = buffer_dim + self.with_coord = with_coord # compute_importance_map takes long time when computing on cpu. We thus # compute it once if it's static and then save it for future usage @@ -525,6 +529,7 @@ def __call__( None, buffer_steps, buffer_dim, + self.with_coord, *args, **kwargs, ) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 92f267e8a2c..a080284e7ca 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -57,6 +57,7 @@ def sliding_window_inference( process_fn: Callable | None = None, buffer_steps: int | None = None, buffer_dim: int = -1, + with_coord: bool = False, *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: @@ -125,6 +126,8 @@ def sliding_window_inference( (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency. buffer_dim: the spatial dimension along which the buffers are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. + with_coord: whether to pass the window coordinates to ``predictor``. Default is False. + If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``. args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. @@ -220,7 +223,10 @@ def sliding_window_inference( win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) else: win_data = inputs[unravel_slice[0]].to(sw_device) - seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch + if with_coord: + seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch + else: + seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory. dict_keys, seg_tuple = _flatten_struct(seg_prob_out) diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index b17e8525ec3..276bd1e3723 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -237,6 +237,7 @@ def compute(data, test1, test2): None, None, 0, + False, t1, test2=t2, ) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index f9d49361a61..8f0c0744033 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -294,6 +294,7 @@ def compute(data, test1, test2): None, None, 0, + False, t1, test2=t2, )