SwinUNETR refactored img_size parameter and removed checkpointing deprecation

Signed-off-by: John Zielke <[email protected]>
john-zielke-snkeos committed Oct 6, 2023
1 parent 100db27 commit 2b01bd4
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions monai/networks/nets/swin_unetr.py
@@ -20,11 +20,13 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from torch.nn import LayerNorm
+from typing_extensions import Final
 
 from monai.networks.blocks import MLPBlock as Mlp
 from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
 from monai.networks.layers import DropPath, trunc_normal_
 from monai.utils import ensure_tuple_rep, look_up_option, optional_import
+from monai.utils.deprecate_utils import deprecated_arg
 
 rearrange, _ = optional_import("einops", name="rearrange")
 
@@ -49,6 +51,15 @@ class SwinUNETR(nn.Module):
     <https://arxiv.org/abs/2201.01266>"
     """
 
+    patch_size: Final[int] = 2
+
+    @deprecated_arg(
+        name="img_size",
+        since="1.3",
+        removed="1.5",
+        msg_suffix="The img_size argument is not required anymore and "
+        "checks on the input size are run during forward().",
+    )
     def __init__(
         self,
         img_size: Sequence[int] | int,
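For context: MONAI's `deprecated_arg` decorator warns when the flagged argument is supplied between its `since` and `removed` versions. A hedged sketch of what callers would observe after this commit (the warning class and wording are assumptions based on MONAI's deprecation utilities, not copied from this diff):

```python
import warnings
from monai.networks.nets import SwinUNETR

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Passing img_size triggers the deprecation machinery added above.
    SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2)

# Expect a FutureWarning naming img_size and ending with the msg_suffix
# from the decorator ("... checks on the input size are run during forward().").
print([str(w.message) for w in caught])
```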
@@ -69,7 +80,10 @@ def __init__(
     ) -> None:
         """
         Args:
-            img_size: dimension of input image.
+            img_size: spatial dimension of input image.
+                This argument is only used for checking that the input image size is divisible by the patch size.
+                The actual tensor shape can be different as long as it is divisible by 2**5.
+                It will be removed in an upcoming version.
             in_channels: dimension of input channels.
             out_channels: dimension of output channels.
             feature_size: dimension of network feature size.
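In other words, `img_size` becomes a validation-only hint: the input tensor no longer has to match it exactly, only satisfy the divisibility constraint. A minimal sketch of the new behavior (shapes and channel counts here are illustrative, not taken from the commit):

```python
import torch
from monai.networks.nets import SwinUNETR

# img_size is still accepted (deprecated since 1.3) but only checked for
# divisibility by patch_size**5 = 32; it no longer fixes the input shape.
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2)

x = torch.randn(1, 1, 64, 64, 64)   # differs from img_size, but 64 % 32 == 0
y = model(x)                        # accepted after this change

x_bad = torch.randn(1, 1, 70, 64, 64)
# model(x_bad) raises ValueError from _check_input_size, since 70 % 32 != 0
```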
@@ -103,16 +117,13 @@ def __init__(
         super().__init__()
 
         img_size = ensure_tuple_rep(img_size, spatial_dims)
-        patch_size = ensure_tuple_rep(2, spatial_dims)
+        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)
 
         if spatial_dims not in (2, 3):
             raise ValueError("spatial dimension should be 2 or 3.")
 
-        for m, p in zip(img_size, patch_size):
-            for i in range(5):
-                if m % np.power(p, i + 1) != 0:
-                    raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
+        self._check_input_size(img_size)
 
         if not (0 <= drop_rate <= 1):
             raise ValueError("dropout rate should be between 0 and 1.")
@@ -132,7 +143,7 @@ def __init__(
             in_chans=in_channels,
             embed_dim=feature_size,
             window_size=window_size,
-            patch_size=patch_size,
+            patch_size=patch_sizes,
             depths=depths,
             num_heads=num_heads,
             mlp_ratio=4.0,
@@ -297,7 +308,20 @@ def load_from(self, weights):
                 weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
             )
 
+    @torch.jit.unused
+    def _check_input_size(self, spatial_shape):
+        img_size = np.array(spatial_shape)
+        remainder = (img_size % np.power(self.patch_size, 5)) > 0
+        if remainder.any():
+            wrong_dims = (np.where(remainder)[0] + 2).tolist()
+            raise ValueError(
+                f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
+                f" must be divisible by {self.patch_size}**5."
+            )
+
     def forward(self, x_in):
+        if not torch.jit.is_scripting():
+            self._check_input_size(x_in.shape[2:])
         hidden_states_out = self.swinViT(x_in, self.normalize)
         enc0 = self.encoder1(x_in)
         enc1 = self.encoder2(hidden_states_out[0])
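The `+ 2` offset maps spatial axes to tensor dimensions in the batch/channel/spatial layout (dims 0 and 1 are batch and channel), and the `torch.jit.is_scripting()` guard plus `@torch.jit.unused` keep the NumPy-based check out of TorchScript. A standalone sketch of the same arithmetic, with an illustrative shape:

```python
import numpy as np

patch_size = 2
spatial_shape = (70, 64, 64)  # e.g. x_in.shape[2:] for a 3D input

remainder = (np.array(spatial_shape) % np.power(patch_size, 5)) > 0  # 2**5 = 32
wrong_dims = (np.where(remainder)[0] + 2).tolist()
print(wrong_dims)  # [2] -- only tensor dim 2 (size 70) fails 70 % 32 == 0
```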
@@ -669,12 +693,12 @@ def load_from(self, weights, n_block, layer):
     def forward(self, x, mask_matrix):
         shortcut = x
         if self.use_checkpoint:
-            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
+            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)
         else:
             x = self.forward_part1(x, mask_matrix)
         x = shortcut + self.drop_path(x)
         if self.use_checkpoint:
-            x = x + checkpoint.checkpoint(self.forward_part2, x)
+            x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)
         else:
             x = x + self.forward_part2(x)
         return x
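Passing `use_reentrant=False` switches to PyTorch's non-reentrant checkpoint implementation and silences the warning recent PyTorch releases emit when the flag is left unspecified; this appears to be the "checkpointing deprecation" of the commit title. A self-contained sketch of the call pattern, with a toy stand-in for `forward_part1`/`forward_part2`:

```python
import torch
import torch.utils.checkpoint as checkpoint

def forward_part(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real block: any differentiable function works.
    return t * torch.sigmoid(t)

x = torch.randn(4, 8, requires_grad=True)
# Activations inside forward_part are recomputed in backward rather than stored.
y = checkpoint.checkpoint(forward_part, x, use_reentrant=False)
y.sum().backward()
```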
