
Mednext #1

Merged
merged 1 commit on Sep 11, 2024
monai/networks/blocks/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@
from .fcn import FCN, GCN, MCFCN, Refine
from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock
from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
from .mlp import MLPBlock
from .patchembedding import PatchEmbed, PatchEmbeddingBlock
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
monai/networks/blocks/mednext_block.py (123 changes: 64 additions & 59 deletions)
@@ -17,7 +17,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


class MedNeXtBlock(nn.Module):
@@ -26,63 +27,65 @@ def __init__(
self,
in_channels: int,
out_channels: int,
exp_r: int = 4,
expansion_ratio: int = 4,
kernel_size: int = 7,
do_res: int = True,
use_residual_connection: int = True,
norm_type: str = "group",
n_groups: int or None = None,
dim="3d",
grn=False,
):

super().__init__()

self.do_res = do_res
self.do_res = use_residual_connection

assert dim in ["2d", "3d"]
self.dim = dim
if self.dim == "2d":
conv = nn.Conv2d
else:
normalized_shape = [in_channels, kernel_size, kernel_size]
grn_parameter_shape = (1, 1)
elif self.dim == "3d":
conv = nn.Conv3d

normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
grn_parameter_shape = (1, 1, 1)
else:
raise ValueError("dim must be either '2d' or '3d'")
# First convolution layer with DepthWise Convolutions
self.conv1 = conv(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
groups=in_channels if n_groups is None else n_groups,
groups=in_channels,
)

# Normalization Layer. GroupNorm is used by default.
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
elif norm_type == "layer":
self.norm = LayerNorm(normalized_shape=in_channels, data_format="channels_first")

self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
# Second convolution (Expansion) layer with Conv3D 1x1x1
self.conv2 = conv(in_channels=in_channels, out_channels=exp_r * in_channels, kernel_size=1, stride=1, padding=0)
self.conv2 = conv(
in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
)

# GeLU activations
self.act = nn.GELU()

# Third convolution (Compression) layer with Conv3D 1x1x1
self.conv3 = conv(
in_channels=exp_r * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
)

self.grn = grn
if self.grn:
if dim == "2d":
self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
else:
self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)

def forward(self, x, dummy_tensor=None):
def forward(self, x):

x1 = x
x1 = self.conv1(x1)
@@ -106,19 +109,34 @@ def forward(self, x, dummy_tensor=None):
class MedNeXtDownBlock(MedNeXtBlock):

def __init__(
self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
self,
in_channels: int,
out_channels: int,
expansion_ratio: int = 4,
kernel_size: int = 7,
use_residual_connection: bool = False,
norm_type: str = "group",
dim: str = "3d",
grn: bool = False,
):

super().__init__(
in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
in_channels,
out_channels,
expansion_ratio,
kernel_size,
use_residual_connection=False,
norm_type=norm_type,
dim=dim,
grn=grn,
)

if dim == "2d":
conv = nn.Conv2d
else:
conv = nn.Conv3d
self.resample_do_res = do_res
if do_res:
self.resample_do_res = use_residual_connection
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

self.conv1 = conv(
@@ -130,7 +148,7 @@ def __init__(
groups=in_channels,
)

def forward(self, x, dummy_tensor=None):
def forward(self, x):

x1 = super().forward(x)

@@ -144,20 +162,35 @@ def forward(self, x, dummy_tensor=None):
class MedNeXtUpBlock(MedNeXtBlock):

def __init__(
self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
self,
in_channels: int,
out_channels: int,
expansion_ratio: int = 4,
kernel_size: int = 7,
use_residual_connection: bool = False,
norm_type: str = "group",
dim: str = "3d",
grn: bool = False,
):
super().__init__(
in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
in_channels,
out_channels,
expansion_ratio,
kernel_size,
use_residual_connection=False,
norm_type=norm_type,
dim=dim,
grn=grn,
)

self.resample_do_res = do_res
self.resample_do_res = use_residual_connection

self.dim = dim
if dim == "2d":
conv = nn.ConvTranspose2d
else:
conv = nn.ConvTranspose3d
if do_res:
if use_residual_connection:
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

self.conv1 = conv(
@@ -169,7 +202,7 @@ def __init__(
groups=in_channels,
)

def forward(self, x, dummy_tensor=None):
def forward(self, x):

x1 = super().forward(x)
# Asymmetry but necessary to match shape
@@ -190,7 +223,7 @@ def forward(self, x, dummy_tensor=None):
return x1


class OutBlock(nn.Module):
class MedNeXtOutBlock(nn.Module):

def __init__(self, in_channels, n_classes, dim):
super().__init__()
@@ -201,33 +234,5 @@ def __init__(self, in_channels, n_classes, dim):
conv = nn.ConvTranspose3d
self.conv_out = conv(in_channels, n_classes, kernel_size=1)

def forward(self, x, dummy_tensor=None):
def forward(self, x):
return self.conv_out(x)


class LayerNorm(nn.Module):
"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""

def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape)) # beta
self.bias = nn.Parameter(torch.zeros(normalized_shape)) # gamma
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)

def forward(self, x, dummy_tensor=False):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
return x
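
For reference, a minimal usage sketch of the renamed block arguments, based on the constructor signatures shown in the mednext_block.py diff above; the channel counts and tensor sizes are arbitrary illustration values, not anything prescribed by the PR.

```python
import torch

from monai.networks.blocks import MedNeXtBlock, MedNeXtOutBlock

# MedNeXtBlock with the renamed keyword arguments (expansion_ratio, use_residual_connection).
block = MedNeXtBlock(
    in_channels=16,
    out_channels=16,
    expansion_ratio=4,
    kernel_size=7,
    use_residual_connection=True,
    norm_type="group",
    dim="3d",
)

# 1x1 output head mapping features to class logits (renamed from OutBlock to MedNeXtOutBlock).
head = MedNeXtOutBlock(in_channels=16, n_classes=2, dim="3d")

x = torch.randn(1, 16, 32, 32, 32)  # arbitrary 3D volume: (batch, channels, D, H, W)
y = head(block(x))                  # spatial size preserved -> (1, 2, 32, 32, 32)
```
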
monai/networks/nets/__init__.py (20 changes: 19 additions & 1 deletion)
@@ -53,7 +53,25 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .mednext import MedNeXt
from .mednext import (
Review thread on this import:

Owner:
@rcremese I would prefer to do a factory method pattern similar to how it exists for resnets: https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/monai/networks/nets/resnet.py#L36 for consistency.

Author (rcremese, Sep 5, 2024):
Well, I took example from DenseNet and SENet, which are both integrated into the core library. Moreover, there are only 4 architecture variants in the original paper, so it shouldn't pollute the script that much.

Owner:
I see the inconsistency now. Factory methods are better in this case, I would say: they keep the variants simpler, since we are not using any functionality provided by subclassing the MedNeXt parent. Moreover, if we want to add custom logic like loading pre-trained weights, factory methods lend themselves better to that too.

MedNeXt,
MedNext,
MedNextB,
MedNeXtB,
MedNextBase,
MedNextL,
MedNeXtL,
MedNeXtLarge,
MedNextLarge,
MedNextM,
MedNeXtM,
MedNeXtMedium,
MedNextMedium,
MedNextS,
MedNeXtS,
MedNeXtSmall,
MedNextSmall,
)
from .milmodel import MILModel
from .netadapter import NetAdapter
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
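
Below is a minimal, self-contained sketch of the factory-method pattern suggested in the review thread above, mirroring the resnet-style factories the owner links to. The create_mednext function, the _DummyMedNeXt stand-in, and the preset values are illustrative placeholders only; they are not the MONAI API and not the S/B/M/L configurations from the MedNeXt paper.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class _DummyMedNeXt:
    """Stand-in for the real MedNeXt network; the actual constructor arguments differ."""

    expansion_ratio: int = 4
    kernel_size: int = 7
    use_residual_connection: bool = True


# Placeholder presets keyed by variant name (values are illustrative, not from the paper).
_PRESETS: dict[str, dict[str, Any]] = {
    "small": {"expansion_ratio": 2, "kernel_size": 3},
    "base": {"expansion_ratio": 4, "kernel_size": 3},
    "large": {"expansion_ratio": 4, "kernel_size": 5},
}


def create_mednext(variant: str = "base", **overrides: Any) -> _DummyMedNeXt:
    """Factory: resolve a named preset, apply caller overrides, and build the model."""
    if variant not in _PRESETS:
        raise ValueError(f"unknown MedNeXt variant: {variant!r}")
    return _DummyMedNeXt(**{**_PRESETS[variant], **overrides})


model = create_mednext("small", use_residual_connection=False)
```

The appeal discussed in the thread is that each variant reduces to a single preset entry while callers can still override individual arguments, and extra logic such as loading pre-trained weights has an obvious place to live.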