SlidingWindowInfererAdapt (#6251)
SlidingWindowInfererAdapt extends SlidingWindowInferer to automatically
switch to buffered and then to CPU stitching when it runs out of GPU
memory. It also records the size of such large images so that CPU
stitching is tried directly for the next large image of a similar size.
If the stitching 'device' input parameter is provided, automatic
adaptation won't be attempted; please keep the default option
device = None for adaptive behavior.
Note: the output might be on CPU (even if the input was on GPU) if the
GPU memory was not sufficient.
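A minimal usage sketch (the 3D UNet, input size, and ROI size below are illustrative assumptions, not part of this change):

import torch
from monai.inferers import SlidingWindowInfererAdapt
from monai.networks.nets import UNet

net = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2)).eval()
image = torch.rand(1, 1, 96, 96, 96)  # (batch, channel, D, H, W)
if torch.cuda.is_available():
    net, image = net.cuda(), image.cuda()

# leave device=None (the default) so the inferer can fall back from GPU stitching
# to buffered and then to CPU stitching if an out-of-memory error occurs
inferer = SlidingWindowInfererAdapt(roi_size=(64, 64, 64), sw_batch_size=2, overlap=0.25)
with torch.no_grad():
    pred = inferer(image, net)
# pred may end up on the CPU even for a GPU input if GPU memory was insufficient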

---
also fixes #6340 by adding one line to the resampling utility in monai/transforms/lazy/utils.py

---------

Signed-off-by: myron <[email protected]>
Signed-off-by: Wenqi Li <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
myron and wyli authored Apr 14, 2023
1 parent 57c618c commit 3633b1c
Showing 5 changed files with 119 additions and 6 deletions.
10 changes: 9 additions & 1 deletion monai/inferers/__init__.py
@@ -11,7 +11,15 @@

from __future__ import annotations

from .inferer import Inferer, PatchInferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer
from .inferer import (
    Inferer,
    PatchInferer,
    SaliencyInferer,
    SimpleInferer,
    SliceInferer,
    SlidingWindowInferer,
    SlidingWindowInfererAdapt,
)
from .merger import AvgMerger, Merger
from .splitter import SlidingWindowSplitter, Splitter
from .utils import sliding_window_inference
102 changes: 99 additions & 3 deletions monai/inferers/inferer.py
@@ -20,14 +20,25 @@
import torch
import torch.nn as nn

from monai.apps.utils import get_logger
from monai.data.meta_tensor import MetaTensor
from monai.inferers.merger import AvgMerger, Merger
from monai.inferers.splitter import Splitter
from monai.inferers.utils import compute_importance_map, sliding_window_inference
from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
from monai.visualize import CAM, GradCAM, GradCAMpp

__all__ = ["Inferer", "PatchInferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer", "SliceInferer"]
logger = get_logger(__name__)

__all__ = [
    "Inferer",
    "PatchInferer",
    "SimpleInferer",
    "SlidingWindowInferer",
    "SaliencyInferer",
    "SliceInferer",
    "SlidingWindowInfererAdapt",
]


class Inferer(ABC):
@@ -448,7 +459,9 @@ def __call__(
"""

device = self.device
device = kwargs.pop("device", self.device)
buffer_steps = kwargs.pop("buffer_steps", self.buffer_steps)

if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:
device = "cpu" # stitch in cpu memory if image is too large

@@ -467,13 +480,96 @@ def __call__(
            self.progress,
            self.roi_weight_map,
            None,
            self.buffer_steps,
            buffer_steps,
            self.buffer_dim,
            *args,
            **kwargs,
        )


class SlidingWindowInfererAdapt(SlidingWindowInferer):
"""
SlidingWindowInfererAdapt extends SlidingWindowInferer to automatically switch to buffered and then to CPU stitching,
when OOM on GPU. It also records a size of such large images to automatically
try CPU stitching for the next large image of a similar size. If the stitching 'device' input parameter is provided,
automatic adaptation won't be attempted, please keep the default option device = None for adaptive behavior.
Note: the output might be on CPU (even if the input was on GPU), if the GPU memory was not sufficient.
"""

    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
"""
Args:
inputs: model input data for inference.
network: target model to execute inference.
supports callables such as ``lambda x: my_torch_model(x, additional_config)``
args: optional args to be passed to ``network``.
kwargs: optional keyword args to be passed to ``network``.
"""

        # if device is provided, use without any adaptations
        if self.device is not None:
            return super().__call__(inputs, network, *args, **kwargs)

        skip_buffer = self.buffer_steps is not None and self.buffer_steps <= 0
        cpu_cond = self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh
        gpu_stitching = inputs.is_cuda and not cpu_cond
        buffered_stitching = inputs.is_cuda and cpu_cond and not skip_buffer
        buffer_steps = max(1, self.buffer_steps) if self.buffer_steps is not None else 1

        for _ in range(10):  # at most 10 trials
            try:
                return super().__call__(
                    inputs,
                    network,
                    device=inputs.device if gpu_stitching else torch.device("cpu"),
                    buffer_steps=buffer_steps if buffered_stitching else None,
                    *args,
                    **kwargs,
                )
            except RuntimeError as e:
                if not gpu_stitching and not buffered_stitching or "OutOfMemoryError" not in str(type(e).__name__):
                    raise e

                logger.info(e)

                if gpu_stitching:  # if failed on gpu
                    gpu_stitching = False
                    self.cpu_thresh = inputs.shape[2:].numel() - 1  # update thresh

                    if skip_buffer:
                        buffered_stitching = False
                        logger.warning(f"GPU stitching failed, attempting on CPU, image dim {inputs.shape}..")

                    else:
                        buffered_stitching = True
                        self.buffer_steps = buffer_steps
                        logger.warning(
                            f"GPU stitching failed, attempting with buffer {buffer_steps}, image dim {inputs.shape}.."
                        )
                elif buffer_steps > 1:
                    buffer_steps = max(1, buffer_steps // 2)
                    self.buffer_steps = buffer_steps
                    logger.warning(
                        f"GPU buffered stitching failed, image dim {inputs.shape} reducing buffer to {buffer_steps}"
                    )
                else:
                    buffered_stitching = False
                    self.buffer_steps = 0  # disable future buffer attempts
                    logger.warning(f"GPU buffered stitching failed, attempting on CPU, image dim {inputs.shape}")
        raise RuntimeError(  # not possible to finish after the trials
            f"SlidingWindowInfererAdapt {skip_buffer} {cpu_cond} {gpu_stitching} {buffered_stitching} {buffer_steps}"
        )


class SaliencyInferer(Inferer):
"""
SaliencyInferer is inference with activation maps.
1 change: 1 addition & 0 deletions monai/transforms/lazy/utils.py
@@ -216,6 +216,7 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None =
        and allclose(convert_to_numpy(in_shape, wrap_sequence=True), out_spatial_size)
    ):
        img.affine = call_kwargs["dst_affine"]
        img = img.to(torch.float32)  # consistent with monai.transforms.spatial.functional.spatial_resample
        return img
    img = monai.transforms.crop_or_pad_nd(img, matrix_np, out_spatial_size, mode=call_kwargs["padding_mode"])
    img = img.to(torch.float32)  # consistent with monai.transforms.spatial.functional.spatial_resample
5 changes: 4 additions & 1 deletion tests/test_resample.py
@@ -28,7 +28,10 @@ def rotate_90_2d():
    return t


RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]])]
RESAMPLE_FUNCTION_CASES = [
    (get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]]),
    (get_arange_img((3, 3)), torch.eye(3), get_arange_img((3, 3))[0]),
]


class TestResampleFunction(unittest.TestCase):
7 changes: 6 additions & 1 deletion tests/test_sliding_window_inference.py
@@ -19,7 +19,7 @@
from parameterized import parameterized

from monai.data.utils import list_data_collate
from monai.inferers import SlidingWindowInferer, sliding_window_inference
from monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference
from monai.utils import optional_import
from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick

@@ -305,6 +305,11 @@ def compute(data, test1, test2):
        )(inputs, compute, t1, test2=t2)
        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)

        result = SlidingWindowInfererAdapt(
            roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm
        )(inputs, compute, t1, test2=t2)
        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)

    def test_multioutput(self):
        device = "cuda" if torch.cuda.is_available() else "cpu:0"
        inputs = torch.ones((1, 6, 20, 20)).to(device=device)
