Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Pai <[email protected]>
  • Loading branch information
surajpaib committed Sep 27, 2024
1 parent fc10369 commit 17d0579
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 34 deletions.
41 changes: 16 additions & 25 deletions monai/networks/blocks/mednext_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
if spatial_dim == 2:
return nn.ConvTranspose2d if transpose else nn.Conv2d
else: # spatial_dim == 3
return nn.ConvTranspose3d if transpose else nn.Conv3d


class MedNeXtBlock(nn.Module):

def __init__(
Expand All @@ -39,18 +46,9 @@ def __init__(

self.do_res = use_residual_connection

assert dim in ["2d", "3d"]
self.dim = dim
if self.dim == "2d":
conv = nn.Conv2d
normalized_shape = [in_channels, kernel_size, kernel_size]
grn_parameter_shape = (1, 1)
elif self.dim == "3d":
conv = nn.Conv3d
normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
grn_parameter_shape = (1, 1, 1)
else:
raise ValueError("dim must be either '2d' or '3d'")
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
grn_parameter_shape = (1,) * (2 if dim == "2d" else 3)
# First convolution layer with DepthWise Convolutions
self.conv1 = conv(
in_channels=in_channels,
Expand All @@ -63,9 +61,11 @@ def __init__(

# Normalization Layer. GroupNorm is used by default.
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore
elif norm_type == "layer":
self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
self.norm = nn.LayerNorm(
normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore
)
# Second convolution (Expansion) layer with Conv3D 1x1x1
self.conv2 = conv(
in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
Expand Down Expand Up @@ -131,10 +131,7 @@ def __init__(
grn=grn,
)

if dim == "2d":
conv = nn.Conv2d
else:
conv = nn.Conv3d
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
self.resample_do_res = use_residual_connection
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
Expand Down Expand Up @@ -186,10 +183,7 @@ def __init__(
self.resample_do_res = use_residual_connection

self.dim = dim
if dim == "2d":
conv = nn.ConvTranspose2d
else:
conv = nn.ConvTranspose3d
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

Expand Down Expand Up @@ -228,10 +222,7 @@ class MedNeXtOutBlock(nn.Module):
def __init__(self, in_channels, n_classes, dim):
super().__init__()

if dim == "2d":
conv = nn.ConvTranspose2d
else:
conv = nn.ConvTranspose3d
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
self.conv_out = conv(in_channels, n_classes, kernel_size=1)

def forward(self, x):
Expand Down
16 changes: 8 additions & 8 deletions monai/networks/nets/mednext.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __init__(
init_filters: int = 32,
in_channels: int = 1,
out_channels: int = 2,
encoder_expansion_ratio: int = 2,
decoder_expansion_ratio: int = 2,
encoder_expansion_ratio: Sequence[int] | int = 2,
decoder_expansion_ratio: Sequence[int] | int = 2,
bottleneck_expansion_ratio: int = 2,
kernel_size: int = 7,
deep_supervision: bool = False,
Expand Down Expand Up @@ -212,7 +212,7 @@ def __init__(
out_blocks.reverse()
self.out_blocks = nn.ModuleList(out_blocks)

def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:
"""
Forward pass of the MedNeXt model.
Expand All @@ -227,7 +227,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor or list[torch.Tensor]: Output tensor(s).
torch.Tensor or Sequence[torch.Tensor]: Output tensor(s).
"""
# Apply stem convolution
x = self.stem(x)
Expand Down Expand Up @@ -311,7 +311,7 @@ def create_mednext(
blocks_down=(2, 2, 2, 2),
blocks_bottleneck=2,
blocks_up=(2, 2, 2, 2),
**common_args,
**common_args, # type: ignore
)
elif variant.upper() == "B":
return MedNeXt(
Expand All @@ -321,7 +321,7 @@ def create_mednext(
blocks_down=(2, 2, 2, 2),
blocks_bottleneck=2,
blocks_up=(2, 2, 2, 2),
**common_args,
**common_args, # type: ignore
)
elif variant.upper() == "M":
return MedNeXt(
Expand All @@ -331,7 +331,7 @@ def create_mednext(
blocks_down=(3, 4, 4, 4),
blocks_bottleneck=4,
blocks_up=(4, 4, 4, 3),
**common_args,
**common_args, # type: ignore
)
elif variant.upper() == "L":
return MedNeXt(
Expand All @@ -341,7 +341,7 @@ def create_mednext(
blocks_down=(3, 4, 8, 8),
blocks_bottleneck=8,
blocks_up=(8, 8, 4, 3),
**common_args,
**common_args, # type: ignore
)
else:
raise ValueError(f"Invalid MedNeXt variant: {variant}")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mednext.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
for spatial_dims in range(2, 4):
for out_channels in [1, 2]:
test_case = [
model,
model, # type: ignore
{"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels},
(2, 1, *([16] * spatial_dims)),
(2, out_channels, *([16] * spatial_dims)),
Expand Down

0 comments on commit 17d0579

Please sign in to comment.