Code formatting for Black and Flake8 checks to pass + integration of MedNeXt variants (S, B, M, L) + integration of remarks from @johnzilke (Project-MONAI#8004 (review)) for renaming class arguments - removal of self-defined LayerNorm - linked residual connection for encoder and decoder

Signed-off-by: Robin CREMESE <[email protected]>
rcremese committed Sep 2, 2024
1 parent d7661dd commit 0361444
Showing 5 changed files with 273 additions and 103 deletions.
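
For orientation, here is a minimal sketch of how the renamed MedNeXtBlock arguments read after this commit; channel counts and input size are illustrative and not taken from the diff:

import torch
from monai.networks.blocks import MedNeXtBlock

# expansion_ratio replaces exp_r, use_residual_connection replaces do_res,
# and forward() no longer takes a dummy_tensor argument.
block = MedNeXtBlock(
    in_channels=32,
    out_channels=32,
    expansion_ratio=4,
    kernel_size=7,
    use_residual_connection=True,
    norm_type="group",
    dim="3d",
)
x = torch.randn(1, 32, 16, 16, 16)  # (batch, channels, D, H, W)
print(block(x).shape)  # expected: torch.Size([1, 32, 16, 16, 16])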
2 changes: 1 addition & 1 deletion monai/networks/blocks/__init__.py
@@ -26,7 +26,7 @@
from .fcn import FCN, GCN, MCFCN, Refine
from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock
from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
from .mlp import MLPBlock
from .patchembedding import PatchEmbed, PatchEmbeddingBlock
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
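The renamed export can be used as a 1x1 segmentation head; a short sketch with made-up channel and class counts, assuming the same "2d"/"3d" convention as the other blocks:

import torch
from monai.networks.blocks import MedNeXtOutBlock  # renamed from OutBlock in this commit

head = MedNeXtOutBlock(in_channels=32, n_classes=3, dim="3d")  # ConvTranspose3d with kernel_size=1
x = torch.randn(1, 32, 16, 16, 16)
print(head(x).shape)  # expected: torch.Size([1, 3, 16, 16, 16])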
123 changes: 64 additions & 59 deletions monai/networks/blocks/mednext_block.py
@@ -17,7 +17,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


class MedNeXtBlock(nn.Module):
@@ -26,63 +27,65 @@ def __init__(
self,
in_channels: int,
out_channels: int,
exp_r: int = 4,
expansion_ratio: int = 4,
kernel_size: int = 7,
do_res: int = True,
use_residual_connection: int = True,
norm_type: str = "group",
n_groups: int or None = None,
dim="3d",
grn=False,
):

super().__init__()

self.do_res = do_res
self.do_res = use_residual_connection

assert dim in ["2d", "3d"]
self.dim = dim
if self.dim == "2d":
conv = nn.Conv2d
else:
normalized_shape = [in_channels, kernel_size, kernel_size]
grn_parameter_shape = (1, 1)
elif self.dim == "3d":
conv = nn.Conv3d

normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
grn_parameter_shape = (1, 1, 1)
else:
raise ValueError("dim must be either '2d' or '3d'")
# First convolution layer with DepthWise Convolutions
self.conv1 = conv(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
groups=in_channels if n_groups is None else n_groups,
groups=in_channels,
)

# Normalization Layer. GroupNorm is used by default.
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
elif norm_type == "layer":
self.norm = LayerNorm(normalized_shape=in_channels, data_format="channels_first")

self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
# Second convolution (Expansion) layer with Conv3D 1x1x1
self.conv2 = conv(in_channels=in_channels, out_channels=exp_r * in_channels, kernel_size=1, stride=1, padding=0)
self.conv2 = conv(
in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
)

# GeLU activations
self.act = nn.GELU()

# Third convolution (Compression) layer with Conv3D 1x1x1
self.conv3 = conv(
in_channels=exp_r * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
)

self.grn = grn
if self.grn:
if dim == "2d":
self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
else:
self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)

def forward(self, x, dummy_tensor=None):
def forward(self, x):

x1 = x
x1 = self.conv1(x1)
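
A small check, derived from the hunk above, that the grn_parameter_shape refactor reproduces the parameter shapes of the old per-dimensionality branches (channel count is illustrative):

expansion_ratio, in_channels = 4, 32
for dim, spatial_ones in (("2d", (1, 1)), ("3d", (1, 1, 1))):
    grn_parameter_shape = (1, expansion_ratio * in_channels) + spatial_ones
    print(dim, grn_parameter_shape)
# 2d (1, 128, 1, 1)    -- matches the old torch.zeros(1, exp_r * in_channels, 1, 1)
# 3d (1, 128, 1, 1, 1) -- matches the old torch.zeros(1, exp_r * in_channels, 1, 1, 1)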
@@ -106,19 +109,34 @@ def forward(self, x, dummy_tensor=None):
class MedNeXtDownBlock(MedNeXtBlock):

def __init__(
self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
self,
in_channels: int,
out_channels: int,
expansion_ratio: int = 4,
kernel_size: int = 7,
use_residual_connection: bool = False,
norm_type: str = "group",
dim: str = "3d",
grn: bool = False,
):

super().__init__(
in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
in_channels,
out_channels,
expansion_ratio,
kernel_size,
use_residual_connection=False,
norm_type=norm_type,
dim=dim,
grn=grn,
)

if dim == "2d":
conv = nn.Conv2d
else:
conv = nn.Conv3d
self.resample_do_res = do_res
if do_res:
self.resample_do_res = use_residual_connection
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

self.conv1 = conv(
@@ -130,7 +148,7 @@ def __init__(
groups=in_channels,
)

def forward(self, x, dummy_tensor=None):
def forward(self, x):

x1 = super().forward(x)

@@ -144,20 +162,35 @@ def forward(self, x, dummy_tensor=None):
class MedNeXtUpBlock(MedNeXtBlock):

def __init__(
self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
self,
in_channels: int,
out_channels: int,
expansion_ratio: int = 4,
kernel_size: int = 7,
use_residual_connection: bool = False,
norm_type: str = "group",
dim: str = "3d",
grn: bool = False,
):
super().__init__(
in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
in_channels,
out_channels,
expansion_ratio,
kernel_size,
use_residual_connection=False,
norm_type=norm_type,
dim=dim,
grn=grn,
)

self.resample_do_res = do_res
self.resample_do_res = use_residual_connection

self.dim = dim
if dim == "2d":
conv = nn.ConvTranspose2d
else:
conv = nn.ConvTranspose3d
if do_res:
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

self.conv1 = conv(
@@ -169,7 +202,7 @@ def __init__(
groups=in_channels,
)

def forward(self, x, dummy_tensor=None):
def forward(self, x):

x1 = super().forward(x)
# Asymmetry but necessary to match shape
@@ -190,7 +223,7 @@ def forward(self, x, dummy_tensor=None):
return x1


class OutBlock(nn.Module):
class MedNeXtOutBlock(nn.Module):

def __init__(self, in_channels, n_classes, dim):
super().__init__()
@@ -201,33 +234,5 @@ def __init__(self, in_channels, n_classes, dim):
conv = nn.ConvTranspose3d
self.conv_out = conv(in_channels, n_classes, kernel_size=1)

def forward(self, x, dummy_tensor=None):
def forward(self, x):
return self.conv_out(x)


class LayerNorm(nn.Module):
"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""

def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape)) # beta
self.bias = nn.Parameter(torch.zeros(normalized_shape)) # gamma
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)

def forward(self, x, dummy_tensor=False):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
return x
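
For reference, the removed channels_first LayerNorm normalised over the channel dimension at each spatial location, which the built-in F.layer_norm reproduces after moving channels to the last dimension. A numerical sketch with a 3D input and illustrative sizes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4, 4)  # (batch, channels, D, H, W)
weight, bias, eps = torch.ones(8), torch.zeros(8), 1e-5

# removed custom implementation (channels_first branch)
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
ref = (x - u) / torch.sqrt(s + eps)
ref = weight[:, None, None, None] * ref + bias[:, None, None, None]

# built-in equivalent: move channels last, normalise, move back
out = F.layer_norm(x.permute(0, 2, 3, 4, 1), (8,), weight, bias, eps).permute(0, 4, 1, 2, 3)
print(torch.allclose(ref, out, atol=1e-6))  # expected: True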
20 changes: 19 additions & 1 deletion monai/networks/nets/__init__.py
@@ -53,7 +53,25 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .mednext import MedNeXt
from .mednext import (
MedNeXt,
MedNext,
MedNextB,
MedNeXtB,
MedNextBase,
MedNextL,
MedNeXtL,
MedNeXtLarge,
MedNextLarge,
MedNextM,
MedNeXtM,
MedNeXtMedium,
MedNextMedium,
MedNextS,
MedNeXtS,
MedNeXtSmall,
MedNextSmall,
)
from .milmodel import MILModel
from .netadapter import NetAdapter
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
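A hypothetical usage sketch for the newly exported variants; the keyword names below (spatial_dims, in_channels, out_channels) are assumptions about the variant constructors and are not shown in this diff:

import torch
from monai.networks.nets import MedNeXtS  # MedNeXtB, MedNeXtM and MedNeXtL follow the same pattern

net = MedNeXtS(spatial_dims=3, in_channels=1, out_channels=2)  # hypothetical keyword arguments
x = torch.randn(1, 1, 64, 64, 64)
logits = net(x)
print(logits.shape)  # full-resolution logits expected, assuming deep supervision is disabled by default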
