From e1517a1d4dd17687296b5b506becdc778c796b16 Mon Sep 17 00:00:00 2001 From: thibaultdvx <154365476+thibaultdvx@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:33:33 +0200 Subject: [PATCH] Clinicadl models (#639) * add nn module * unittests --- clinicadl/nn/__init__.py | 0 clinicadl/nn/blocks/__init__.py | 5 + clinicadl/nn/blocks/decoder.py | 185 ++++++ clinicadl/nn/blocks/encoder.py | 169 ++++++ clinicadl/nn/blocks/residual.py | 40 ++ clinicadl/nn/blocks/se.py | 82 +++ clinicadl/nn/blocks/unet.py | 71 +++ clinicadl/nn/layers/__init__.py | 5 + clinicadl/nn/layers/factory/__init__.py | 3 + clinicadl/nn/layers/factory/conv.py | 28 + clinicadl/nn/layers/factory/norm.py | 105 ++++ clinicadl/nn/layers/factory/pool.py | 81 +++ clinicadl/nn/layers/pool.py | 81 +++ clinicadl/nn/layers/reverse.py | 30 + clinicadl/nn/layers/unflatten.py | 35 ++ clinicadl/nn/layers/unpool.py | 32 + clinicadl/nn/networks/__init__.py | 21 + clinicadl/nn/networks/ae.py | 147 +++++ clinicadl/nn/networks/cnn.py | 288 +++++++++ clinicadl/nn/networks/factory/__init__.py | 3 + clinicadl/nn/networks/factory/ae.py | 142 +++++ clinicadl/nn/networks/factory/resnet.py | 119 ++++ clinicadl/nn/networks/factory/secnn.py | 61 ++ clinicadl/nn/networks/random.py | 222 +++++++ clinicadl/nn/networks/ssda.py | 111 ++++ clinicadl/nn/networks/unet.py | 39 ++ clinicadl/nn/networks/vae.py | 566 ++++++++++++++++++ clinicadl/nn/utils.py | 74 +++ clinicadl/utils/enum.py | 11 + tests/unittests/nn/__init__.py | 0 tests/unittests/nn/blocks/__init__.py | 0 tests/unittests/nn/blocks/test_decoder.py | 59 ++ tests/unittests/nn/blocks/test_encoder.py | 46 ++ tests/unittests/nn/blocks/test_residual.py | 9 + tests/unittests/nn/blocks/test_se.py | 33 + tests/unittests/nn/blocks/test_unet.py | 36 ++ tests/unittests/nn/layers/__init__.py | 0 tests/unittests/nn/layers/factory/__init__.py | 0 .../nn/layers/factory/test_factories.py | 27 + tests/unittests/nn/layers/test_layers.py | 101 ++++ tests/unittests/nn/networks/__init__.py | 0 .../unittests/nn/networks/factory/__init__.py | 0 .../nn/networks/factory/test_ae_factory.py | 68 +++ .../networks/factory/test_resnet_factory.py | 70 +++ .../nn/networks/factory/test_secnn_factory.py | 30 + tests/unittests/nn/networks/test_ae.py | 25 + tests/unittests/nn/networks/test_cnn.py | 32 + tests/unittests/nn/networks/test_ssda.py | 11 + tests/unittests/nn/networks/test_unet.py | 9 + tests/unittests/nn/networks/test_vae.py | 87 +++ tests/unittests/nn/test_utils.py | 49 ++ 51 files changed, 3448 insertions(+) create mode 100644 clinicadl/nn/__init__.py create mode 100644 clinicadl/nn/blocks/__init__.py create mode 100644 clinicadl/nn/blocks/decoder.py create mode 100644 clinicadl/nn/blocks/encoder.py create mode 100644 clinicadl/nn/blocks/residual.py create mode 100644 clinicadl/nn/blocks/se.py create mode 100644 clinicadl/nn/blocks/unet.py create mode 100644 clinicadl/nn/layers/__init__.py create mode 100644 clinicadl/nn/layers/factory/__init__.py create mode 100644 clinicadl/nn/layers/factory/conv.py create mode 100644 clinicadl/nn/layers/factory/norm.py create mode 100644 clinicadl/nn/layers/factory/pool.py create mode 100644 clinicadl/nn/layers/pool.py create mode 100644 clinicadl/nn/layers/reverse.py create mode 100644 clinicadl/nn/layers/unflatten.py create mode 100644 clinicadl/nn/layers/unpool.py create mode 100644 clinicadl/nn/networks/__init__.py create mode 100644 clinicadl/nn/networks/ae.py create mode 100644 clinicadl/nn/networks/cnn.py create mode 100644 clinicadl/nn/networks/factory/__init__.py create mode 100644 clinicadl/nn/networks/factory/ae.py create mode 100644 clinicadl/nn/networks/factory/resnet.py create mode 100644 clinicadl/nn/networks/factory/secnn.py create mode 100644 clinicadl/nn/networks/random.py create mode 100644 clinicadl/nn/networks/ssda.py create mode 100644 clinicadl/nn/networks/unet.py create mode 100644 clinicadl/nn/networks/vae.py create mode 100644 clinicadl/nn/utils.py create mode 100644 tests/unittests/nn/__init__.py create mode 100644 tests/unittests/nn/blocks/__init__.py create mode 100644 tests/unittests/nn/blocks/test_decoder.py create mode 100644 tests/unittests/nn/blocks/test_encoder.py create mode 100644 tests/unittests/nn/blocks/test_residual.py create mode 100644 tests/unittests/nn/blocks/test_se.py create mode 100644 tests/unittests/nn/blocks/test_unet.py create mode 100644 tests/unittests/nn/layers/__init__.py create mode 100644 tests/unittests/nn/layers/factory/__init__.py create mode 100644 tests/unittests/nn/layers/factory/test_factories.py create mode 100644 tests/unittests/nn/layers/test_layers.py create mode 100644 tests/unittests/nn/networks/__init__.py create mode 100644 tests/unittests/nn/networks/factory/__init__.py create mode 100644 tests/unittests/nn/networks/factory/test_ae_factory.py create mode 100644 tests/unittests/nn/networks/factory/test_resnet_factory.py create mode 100644 tests/unittests/nn/networks/factory/test_secnn_factory.py create mode 100644 tests/unittests/nn/networks/test_ae.py create mode 100644 tests/unittests/nn/networks/test_cnn.py create mode 100644 tests/unittests/nn/networks/test_ssda.py create mode 100644 tests/unittests/nn/networks/test_unet.py create mode 100644 tests/unittests/nn/networks/test_vae.py create mode 100644 tests/unittests/nn/test_utils.py diff --git a/clinicadl/nn/__init__.py b/clinicadl/nn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/nn/blocks/__init__.py b/clinicadl/nn/blocks/__init__.py new file mode 100644 index 000000000..1f15bafb2 --- /dev/null +++ b/clinicadl/nn/blocks/__init__.py @@ -0,0 +1,5 @@ +from .decoder import Decoder2D, Decoder3D, VAE_Decoder2D +from .encoder import Encoder2D, Encoder3D, VAE_Encoder2D +from .residual import ResBlock +from .se import ResBlock_SE, SE_Block +from .unet import UNetDown, UNetFinalLayer, UNetUp diff --git a/clinicadl/nn/blocks/decoder.py b/clinicadl/nn/blocks/decoder.py new file mode 100644 index 000000000..27938c8d7 --- /dev/null +++ b/clinicadl/nn/blocks/decoder.py @@ -0,0 +1,185 @@ +import torch.nn as nn +import torch.nn.functional as F + +from clinicadl.nn.layers import Unflatten2D, get_norm_layer + +__all__ = [ + "Decoder2D", + "Decoder3D", + "VAE_Decoder2D", +] + + +class Decoder2D(nn.Module): + """ + Class defining the decoder's part of the Autoencoder. + This layer is composed of one 2D transposed convolutional layer, + a batch normalization layer with a relu activation function. + """ + + def __init__( + self, + input_channels, + output_channels, + kernel_size=4, + stride=2, + padding=1, + output_padding=0, + normalization="BatchNorm", + ): + super(Decoder2D, self).__init__() + self.layer = nn.Sequential( + nn.ConvTranspose2d( + input_channels, + output_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=False, + ), + get_norm_layer(normalization, dim=2)(output_channels), + ) + + def forward(self, x): + x = F.relu(self.layer(x), inplace=True) + return x + + +class Decoder3D(nn.Module): + """ + Class defining the decoder's part of the Autoencoder. + This layer is composed of one 3D transposed convolutional layer, + a batch normalization layer with a relu activation function. + """ + + def __init__( + self, + input_channels, + output_channels, + kernel_size=4, + stride=2, + padding=1, + output_padding=0, + normalization="BatchNorm", + ): + super(Decoder3D, self).__init__() + self.layer = nn.Sequential( + nn.ConvTranspose3d( + input_channels, + output_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=False, + ), + get_norm_layer(normalization, dim=3)(output_channels), + ) + + def forward(self, x): + x = F.relu(self.layer(x), inplace=True) + return x + + +class VAE_Decoder2D(nn.Module): + def __init__( + self, + input_shape, + latent_size, + n_conv=4, + last_layer_channels=32, + latent_dim=1, + feature_size=1024, + padding=None, + ): + """ + Feature size is the size of the vector if latent_dim=1 + or is the W/H of the output channels if laten_dim=2 + """ + super(VAE_Decoder2D, self).__init__() + + self.input_c = input_shape[0] + self.input_h = input_shape[1] + self.input_w = input_shape[2] + + if not padding: + output_padding = [[0, 0] for _ in range(n_conv)] + else: + output_padding = padding + + self.layers = [] + + if latent_dim == 1: + n_pix = ( + last_layer_channels + * 2 ** (n_conv - 1) + * (self.input_h // (2**n_conv)) + * (self.input_w // (2**n_conv)) + ) + self.layers.append( + nn.Sequential( + nn.Linear(latent_size, feature_size), + nn.ReLU(), + nn.Linear(feature_size, n_pix), + nn.ReLU(), + Unflatten2D( + last_layer_channels * 2 ** (n_conv - 1), + self.input_h // (2**n_conv), + self.input_w // (2**n_conv), + ), + nn.ReLU(), + ) + ) + elif latent_dim == 2: + self.layers.append( + nn.Sequential( + nn.ConvTranspose2d( + latent_size, feature_size, 3, stride=1, padding=1, bias=False + ), + nn.ReLU(), + nn.ConvTranspose2d( + feature_size, + last_layer_channels * 2 ** (n_conv - 1), + 3, + stride=1, + padding=1, + bias=False, + ), + nn.ReLU(), + ) + ) + else: + raise AttributeError( + "Bad latent dimension specified. Latent dimension must be 1 or 2" + ) + + for i in range(n_conv - 1, 0, -1): + self.layers.append( + Decoder2D( + last_layer_channels * 2 ** (i), + last_layer_channels * 2 ** (i - 1), + output_padding=output_padding[i], + ) + ) + + self.layers.append( + nn.Sequential( + nn.ConvTranspose2d( + last_layer_channels, + self.input_c, + 4, + stride=2, + padding=1, + output_padding=output_padding[0], + bias=False, + ), + nn.Sigmoid(), + ) + ) + + self.sequential = nn.Sequential(*self.layers) + + def forward(self, z): + y = self.sequential(z) + return y diff --git a/clinicadl/nn/blocks/encoder.py b/clinicadl/nn/blocks/encoder.py new file mode 100644 index 000000000..fde13b956 --- /dev/null +++ b/clinicadl/nn/blocks/encoder.py @@ -0,0 +1,169 @@ +import torch.nn as nn +import torch.nn.functional as F + +from clinicadl.nn.layers import get_norm_layer + +__all__ = [ + "Encoder2D", + "Encoder3D", + "VAE_Encoder2D", +] + + +class Encoder2D(nn.Module): + """ + Class defining the encoder's part of the Autoencoder. + This layer is composed of one 2D convolutional layer, + a batch normalization layer with a leaky relu + activation function. + """ + + def __init__( + self, + input_channels, + output_channels, + kernel_size=4, + stride=2, + padding=1, + normalization="BatchNorm", + ): + super(Encoder2D, self).__init__() + self.layer = nn.Sequential( + nn.Conv2d( + input_channels, + output_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ), + get_norm_layer(normalization, dim=2)( + output_channels + ), # TODO : will raise an error if GroupNorm + ) + + def forward(self, x): + x = F.leaky_relu(self.layer(x), negative_slope=0.2, inplace=True) + return x + + +class Encoder3D(nn.Module): + """ + Class defining the encoder's part of the Autoencoder. + This layer is composed of one 3D convolutional layer, + a batch normalization layer with a leaky relu + activation function. + """ + + def __init__( + self, + input_channels, + output_channels, + kernel_size=4, + stride=2, + padding=1, + normalization="BatchNorm", + ): + super(Encoder3D, self).__init__() + self.layer = nn.Sequential( + nn.Conv3d( + input_channels, + output_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ), + get_norm_layer(normalization, dim=3)(output_channels), + ) + + def forward(self, x): + x = F.leaky_relu(self.layer(x), negative_slope=0.2, inplace=True) + return x + + +class VAE_Encoder2D(nn.Module): + def __init__( + self, + input_shape, + n_conv=4, + first_layer_channels=32, + latent_dim=1, + feature_size=1024, + ): + """ + Feature size is the size of the vector if latent_dim=1 + or is the number of feature maps (number of channels) if latent_dim=2 + """ + super(VAE_Encoder2D, self).__init__() + + self.input_c = input_shape[0] + self.input_h = input_shape[1] + self.input_w = input_shape[2] + + decoder_padding = [] + tensor_h, tensor_w = self.input_h, self.input_w + + self.layers = [] + + # Input Layer + self.layers.append(Encoder2D(self.input_c, first_layer_channels)) + padding_h, padding_w = 0, 0 + if tensor_h % 2 != 0: + padding_h = 1 + if tensor_w % 2 != 0: + padding_w = 1 + decoder_padding.append([padding_h, padding_w]) + tensor_h, tensor_w = tensor_h // 2, tensor_w // 2 + # Conv Layers + for i in range(n_conv - 1): + self.layers.append( + Encoder2D( + first_layer_channels * 2**i, first_layer_channels * 2 ** (i + 1) + ) + ) + padding_h, padding_w = 0, 0 + if tensor_h % 2 != 0: + padding_h = 1 + if tensor_w % 2 != 0: + padding_w = 1 + decoder_padding.append([padding_h, padding_w]) + tensor_h, tensor_w = tensor_h // 2, tensor_w // 2 + + self.decoder_padding = decoder_padding + + # Final Layer + if latent_dim == 1: + n_pix = ( + first_layer_channels + * 2 ** (n_conv - 1) + * (self.input_h // (2**n_conv)) + * (self.input_w // (2**n_conv)) + ) + self.layers.append( + nn.Sequential(nn.Flatten(), nn.Linear(n_pix, feature_size), nn.ReLU()) + ) + elif latent_dim == 2: + self.layers.append( + nn.Sequential( + nn.Conv2d( + first_layer_channels * 2 ** (n_conv - 1), + feature_size, + 3, + stride=1, + padding=1, + bias=False, + ), + nn.ReLU(), + ) + ) + else: + raise AttributeError( + "Bad latent dimension specified. Latent dimension must be 1 or 2" + ) + + self.sequential = nn.Sequential(*self.layers) + + def forward(self, x): + z = self.sequential(x) + return z diff --git a/clinicadl/nn/blocks/residual.py b/clinicadl/nn/blocks/residual.py new file mode 100644 index 000000000..ec0a07316 --- /dev/null +++ b/clinicadl/nn/blocks/residual.py @@ -0,0 +1,40 @@ +import torch.nn as nn + + +class ResBlock(nn.Module): + def __init__(self, block_number, input_size): + super(ResBlock, self).__init__() + + layer_in = input_size if input_size is not None else 2 ** (block_number + 1) + layer_out = 2 ** (block_number + 2) + + self.conv1 = nn.Conv3d( + layer_in, layer_out, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm3d(layer_out) + self.act1 = nn.ELU() + + self.conv2 = nn.Conv3d( + layer_out, layer_out, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm3d(layer_out) + + # shortcut + self.shortcut = nn.Sequential( + nn.Conv3d( + layer_in, layer_out, kernel_size=1, stride=1, padding=0, bias=False + ) + ) + + self.act2 = nn.ELU() + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.bn2(out) + out += self.shortcut(x) + out = self.act2(out) + return out diff --git a/clinicadl/nn/blocks/se.py b/clinicadl/nn/blocks/se.py new file mode 100644 index 000000000..f406b7a92 --- /dev/null +++ b/clinicadl/nn/blocks/se.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn + + +class SE_Block(nn.Module): + def __init__(self, num_channels, ratio_channel): + super().__init__() + self.num_channels = num_channels + self.avg_pooling_3D = nn.AdaptiveAvgPool3d(1) + num_channels_reduced = num_channels // ratio_channel + self.fc1 = nn.Linear(num_channels, num_channels_reduced) + self.fc2 = nn.Linear(num_channels_reduced, num_channels) + self.act1 = nn.ReLU() + self.act2 = nn.Sigmoid() + + def forward(self, input_tensor): + """ + Parameters + ---------- + input_tensor: pt tensor + X, shape = (batch_size, num_channels, D, H, W) + + Returns + ------- + output_tensor: pt tensor + """ + batch_size, num_channels, D, H, W = input_tensor.size() + + # Average along each channel + squeeze_tensor = self.avg_pooling_3D(input_tensor) + + # channel excitation + fc_out_1 = self.act1(self.fc1(squeeze_tensor.view(batch_size, num_channels))) + fc_out_2 = self.act2(self.fc2(fc_out_1)) + + output_tensor = torch.mul( + input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1, 1) + ) + + return output_tensor + + +class ResBlock_SE(nn.Module): + def __init__(self, block_number, input_size, num_channels, ratio_channel=8): + super(ResBlock_SE, self).__init__() + + layer_in = input_size if input_size is not None else 2 ** (block_number + 1) + layer_out = 2 ** (block_number + 2) + + self.conv1 = nn.Conv3d( + layer_in, layer_out, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm3d(layer_out) + self.act1 = nn.ReLU() + + self.conv2 = nn.Conv3d( + layer_out, layer_out, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm3d(layer_out) + + self.se_block = SE_Block(layer_out, ratio_channel) + + # shortcut + self.shortcut = nn.Sequential( + nn.Conv3d( + layer_in, layer_out, kernel_size=1, stride=1, padding=0, bias=False + ) + ) + + self.act2 = nn.ReLU() + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se_block(out) + out += self.shortcut(x) + out = self.act2(out) + return out diff --git a/clinicadl/nn/blocks/unet.py b/clinicadl/nn/blocks/unet.py new file mode 100644 index 000000000..4ca275cbd --- /dev/null +++ b/clinicadl/nn/blocks/unet.py @@ -0,0 +1,71 @@ +import torch +from torch import nn + + +class UNetDown(nn.Module): + """Descending block of the U-Net. + + Args: + in_size: (int) number of channels in the input image. + out_size : (int) number of channels in the output image. + + """ + + def __init__(self, in_size, out_size): + super(UNetDown, self).__init__() + self.model = nn.Sequential( + nn.Conv3d(in_size, out_size, kernel_size=3, stride=2, padding=1), + nn.InstanceNorm3d(out_size), + nn.LeakyReLU(0.2), + ) + + def forward(self, x): + return self.model(x) + + +class UNetUp(nn.Module): + """Ascending block of the U-Net. + + Args: + in_size: (int) number of channels in the input image. + out_size : (int) number of channels in the output image. + + """ + + def __init__(self, in_size, out_size): + super(UNetUp, self).__init__() + self.model = nn.Sequential( + nn.ConvTranspose3d(in_size, out_size, kernel_size=4, stride=2, padding=1), + nn.InstanceNorm3d(out_size), + nn.ReLU(inplace=True), + ) + + def forward(self, x, skip_input=None): + if skip_input is not None: + x = torch.cat((x, skip_input), 1) # add the skip connection + x = self.model(x) + return x + + +class UNetFinalLayer(nn.Module): + """Final block of the U-Net. + + Args: + in_size: (int) number of channels in the input image. + out_size : (int) number of channels in the output image. + + """ + + def __init__(self, in_size, out_size): + super(UNetFinalLayer, self).__init__() + self.model = nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv3d(in_size, out_size, kernel_size=3, padding=1), + nn.Tanh(), + ) + + def forward(self, x, skip_input=None): + if skip_input is not None: + x = torch.cat((x, skip_input), 1) # add the skip connection + x = self.model(x) + return x diff --git a/clinicadl/nn/layers/__init__.py b/clinicadl/nn/layers/__init__.py new file mode 100644 index 000000000..00c3ed1ae --- /dev/null +++ b/clinicadl/nn/layers/__init__.py @@ -0,0 +1,5 @@ +from .factory import get_conv_layer, get_norm_layer, get_pool_layer +from .pool import PadMaxPool2d, PadMaxPool3d +from .reverse import GradientReversal +from .unflatten import Reshape, Unflatten2D, Unflatten3D +from .unpool import CropMaxUnpool2d, CropMaxUnpool3d diff --git a/clinicadl/nn/layers/factory/__init__.py b/clinicadl/nn/layers/factory/__init__.py new file mode 100644 index 000000000..55988c334 --- /dev/null +++ b/clinicadl/nn/layers/factory/__init__.py @@ -0,0 +1,3 @@ +from .conv import get_conv_layer +from .norm import get_norm_layer +from .pool import get_pool_layer diff --git a/clinicadl/nn/layers/factory/conv.py b/clinicadl/nn/layers/factory/conv.py new file mode 100644 index 000000000..3a789da61 --- /dev/null +++ b/clinicadl/nn/layers/factory/conv.py @@ -0,0 +1,28 @@ +from typing import Type, Union + +import torch.nn as nn + + +def get_conv_layer(dim: int) -> Union[Type[nn.Conv2d], Type[nn.Conv3d]]: + """ + A factory function for creating Convolutional layers. + + Parameters + ---------- + dim : int + Dimension of the image. + + Returns + ------- + Type[nn.Module] + The Convolutional layer. + + Raises + ------ + AssertionError + If dim is not 2 or 3. + """ + assert dim in {2, 3}, "Input dimension must be 2 or 3." + + layers = (nn.Conv2d, nn.Conv3d) + return layers[dim - 2] diff --git a/clinicadl/nn/layers/factory/norm.py b/clinicadl/nn/layers/factory/norm.py new file mode 100644 index 000000000..a95022924 --- /dev/null +++ b/clinicadl/nn/layers/factory/norm.py @@ -0,0 +1,105 @@ +from typing import Type, Union + +import torch.nn as nn + +from clinicadl.utils.enum import BaseEnum + + +class Normalization(str, BaseEnum): # TODO : remove from global enum + """Available normalization layers in ClinicaDL.""" + + BATCH = "BatchNorm" + GROUP = "GroupNorm" + INSTANCE = "InstanceNorm" + + +def get_norm_layer( + normalization: Union[str, Normalization], dim: int +) -> Type[nn.Module]: + """ + A factory function for creating Normalization layers. + + Parameters + ---------- + normalization : Normalization + Type of normalization. + dim : int + Dimension of the image. + + Returns + ------- + Type[nn.Module] + The normalization layer. + + Raises + ------ + AssertionError + If dim is not 2 or 3. + """ + assert dim in {2, 3}, "Input dimension must be 2 or 3." + normalization = Normalization(normalization) + + if normalization == Normalization.BATCH: + factory = _batch_norm_factory + elif normalization == Normalization.INSTANCE: + factory = _instance_norm_factory + elif normalization == Normalization.GROUP: + factory = _group_norm_factory + return factory(dim) + + +def _instance_norm_factory( + dim: int, +) -> Union[Type[nn.InstanceNorm2d], Type[nn.InstanceNorm3d]]: + """ + A factory function for creating Instance Normalization layers. + + Parameters + ---------- + dim : int + Dimension of the image. + + Returns + ------- + Union[Type[nn.InstanceNorm2d], Type[nn.InstanceNorm3d]] + The normalization layer. + """ + layers = (nn.InstanceNorm2d, nn.InstanceNorm3d) + return layers[dim - 2] + + +def _batch_norm_factory(dim: int) -> Union[Type[nn.BatchNorm2d], Type[nn.BatchNorm3d]]: + """ + A factory function for creating Batch Normalization layers. + + Parameters + ---------- + dim : int + Dimension of the image. + + Returns + ------- + Union[Type[nn.BatchNorm2d], Type[nn.BatchNorm3d]] + The normalization layer. + """ + layers = (nn.BatchNorm2d, nn.BatchNorm3d) + return layers[dim - 2] + + +def _group_norm_factory(dim: int) -> Type[nn.GroupNorm]: + """ + A dummy function that returns a Group Normalization layer. + + To match other factory functions. + + Parameters + ---------- + dim : int + Dimension of the image. + + Returns + ------- + Type[nn.GroupNorm] + The normalization layer. + """ + return nn.GroupNorm diff --git a/clinicadl/nn/layers/factory/pool.py b/clinicadl/nn/layers/factory/pool.py new file mode 100644 index 000000000..b48cf7465 --- /dev/null +++ b/clinicadl/nn/layers/factory/pool.py @@ -0,0 +1,81 @@ +from typing import TYPE_CHECKING, Type, Union + +import torch.nn as nn + +from clinicadl.utils.enum import BaseEnum + +from ..pool import PadMaxPool2d, PadMaxPool3d + + +class Pooling(str, BaseEnum): + """Available pooling layers in ClinicaDL.""" + + MAX = "MaxPool" + PADMAX = "PadMaxPool" + + +def get_pool_layer(pooling: Union[str, Pooling], dim: int) -> Type[nn.Module]: + """ + A factory object for creating Pooling layers. + + Parameters + ---------- + pooling : Pooling + Type of pooling. + dim : int + Dimension of the image. + + Returns + ------- + Type[nn.Module] + The pooling layer. + + Raises + ------ + AssertionError + If dim is not 2 or 3. + """ + assert dim in {2, 3}, "Input dimension must be 2 or 3." + pooling = Pooling(pooling) + + if pooling == Pooling.MAX: + factory = _max_pool_factory + elif pooling == Pooling.PADMAX: + factory = _pad_max_pool_factory + return factory(dim) + + +def _max_pool_factory(dim: int) -> Union[Type[nn.MaxPool2d], Type[nn.MaxPool3d]]: + """ + A factory object for creating Max Pooling layers. + + Parameters + ---------- + dim : int + Dimension of the image. + + Returns + ------- + Union[Type[nn.MaxPool2d], Type[nn.MaxPool3d]] + The pooling layer. + """ + layers = (nn.MaxPool2d, nn.MaxPool3d) + return layers[dim - 2] + + +def _pad_max_pool_factory(dim: int) -> Union[Type[PadMaxPool2d], Type[PadMaxPool3d]]: + """ + A factory object for creating Pad-Max Pooling layers. + + Parameters + ---------- + dim : int + Dimension of the image. + + Returns + ------- + Union[Type[PadMaxPool2d], Type[PadMaxPool3d]] + The pooling layer. + """ + layers = (PadMaxPool2d, PadMaxPool3d) + return layers[dim - 2] diff --git a/clinicadl/nn/layers/pool.py b/clinicadl/nn/layers/pool.py new file mode 100644 index 000000000..92ff2d5dd --- /dev/null +++ b/clinicadl/nn/layers/pool.py @@ -0,0 +1,81 @@ +import torch.nn as nn + + +class PadMaxPool3d(nn.Module): + def __init__(self, kernel_size, stride, return_indices=False, return_pad=False): + super(PadMaxPool3d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.pool = nn.MaxPool3d(kernel_size, stride, return_indices=return_indices) + self.pad = nn.ConstantPad3d(padding=0, value=0) + self.return_indices = return_indices + self.return_pad = return_pad + + def set_new_return(self, return_indices=True, return_pad=True): + self.return_indices = return_indices + self.return_pad = return_pad + self.pool.return_indices = return_indices + + def forward(self, f_maps): + coords = [self.stride - f_maps.size(i + 2) % self.stride for i in range(3)] + for i, coord in enumerate(coords): + if coord == self.stride: + coords[i] = 0 + + self.pad.padding = (coords[2], 0, coords[1], 0, coords[0], 0) + + if self.return_indices: + output, indices = self.pool(self.pad(f_maps)) + + if self.return_pad: + return output, indices, (coords[2], 0, coords[1], 0, coords[0], 0) + else: + return output, indices + + else: + output = self.pool(self.pad(f_maps)) + + if self.return_pad: + return output, (coords[2], 0, coords[1], 0, coords[0], 0) + else: + return output + + +class PadMaxPool2d(nn.Module): + def __init__(self, kernel_size, stride, return_indices=False, return_pad=False): + super(PadMaxPool2d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.pool = nn.MaxPool2d(kernel_size, stride, return_indices=return_indices) + self.pad = nn.ConstantPad2d(padding=0, value=0) + self.return_indices = return_indices + self.return_pad = return_pad + + def set_new_return(self, return_indices=True, return_pad=True): + self.return_indices = return_indices + self.return_pad = return_pad + self.pool.return_indices = return_indices + + def forward(self, f_maps): + coords = [self.stride - f_maps.size(i + 2) % self.stride for i in range(2)] + for i, coord in enumerate(coords): + if coord == self.stride: + coords[i] = 0 + + self.pad.padding = (coords[1], 0, coords[0], 0) + + if self.return_indices: + output, indices = self.pool(self.pad(f_maps)) + + if self.return_pad: + return output, indices, (coords[1], 0, coords[0], 0) + else: + return output, indices + + else: + output = self.pool(self.pad(f_maps)) + + if self.return_pad: + return output, (coords[1], 0, coords[0], 0) + else: + return output diff --git a/clinicadl/nn/layers/reverse.py b/clinicadl/nn/layers/reverse.py new file mode 100644 index 000000000..d433ac47f --- /dev/null +++ b/clinicadl/nn/layers/reverse.py @@ -0,0 +1,30 @@ +import torch +from torch import nn +from torch.autograd import Function + + +class GradientReversalFunction(Function): + @staticmethod + def forward(ctx, x, alpha): + ctx.save_for_backward(x, alpha) + return x + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + _, alpha = ctx.saved_tensors + if ctx.needs_input_grad[0]: + grad_input = -alpha * grad_output + return grad_input, None + + +revgrad = GradientReversalFunction.apply + + +class GradientReversal(nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = torch.tensor(alpha, requires_grad=False) + + def forward(self, x): + return revgrad(x, self.alpha) diff --git a/clinicadl/nn/layers/unflatten.py b/clinicadl/nn/layers/unflatten.py new file mode 100644 index 000000000..75d4eb92f --- /dev/null +++ b/clinicadl/nn/layers/unflatten.py @@ -0,0 +1,35 @@ +import torch.nn as nn + + +class Unflatten2D(nn.Module): + def __init__(self, channel, height, width): + super(Unflatten2D, self).__init__() + self.channel = channel + self.height = height + self.width = width + + def forward(self, input): + return input.view(input.size(0), self.channel, self.height, self.width) + + +class Unflatten3D(nn.Module): + def __init__(self, channel, height, width, depth): + super(Unflatten3D, self).__init__() + self.channel = channel + self.height = height + self.width = width + self.depth = depth + + def forward(self, input): + return input.view( + input.size(0), self.channel, self.height, self.width, self.depth + ) + + +class Reshape(nn.Module): # TODO : redundant with Unflatten + def __init__(self, size): + super(Reshape, self).__init__() + self.size = size + + def forward(self, input): + return input.view(*self.size) diff --git a/clinicadl/nn/layers/unpool.py b/clinicadl/nn/layers/unpool.py new file mode 100644 index 000000000..90da20a3e --- /dev/null +++ b/clinicadl/nn/layers/unpool.py @@ -0,0 +1,32 @@ +import torch.nn as nn + + +class CropMaxUnpool3d(nn.Module): + def __init__(self, kernel_size, stride): + super(CropMaxUnpool3d, self).__init__() + self.unpool = nn.MaxUnpool3d(kernel_size, stride) + + def forward(self, f_maps, indices, padding=None): + output = self.unpool(f_maps, indices) + if padding is not None: + x1 = padding[4] + y1 = padding[2] + z1 = padding[0] + output = output[:, :, x1::, y1::, z1::] + + return output + + +class CropMaxUnpool2d(nn.Module): + def __init__(self, kernel_size, stride): + super(CropMaxUnpool2d, self).__init__() + self.unpool = nn.MaxUnpool2d(kernel_size, stride) + + def forward(self, f_maps, indices, padding=None): + output = self.unpool(f_maps, indices) + if padding is not None: + x1 = padding[2] + y1 = padding[0] + output = output[:, :, x1::, y1::] + + return output diff --git a/clinicadl/nn/networks/__init__.py b/clinicadl/nn/networks/__init__.py new file mode 100644 index 000000000..c77097e60 --- /dev/null +++ b/clinicadl/nn/networks/__init__.py @@ -0,0 +1,21 @@ +from .ae import AE_Conv4_FC3, AE_Conv5_FC3, CAE_half +from .cnn import ( + Conv4_FC3, + Conv5_FC3, + ResNet3D, + SqueezeExcitationCNN, + Stride_Conv5_FC3, + resnet18, +) +from .random import RandomArchitecture +from .ssda import Conv5_FC3_SSDA +from .unet import UNet +from .vae import ( + CVAE_3D, + CVAE_3D_final_conv, + CVAE_3D_half, + VanillaDenseVAE, + VanillaDenseVAE3D, + VanillaSpatialVAE, + VanillaSpatialVAE3D, +) diff --git a/clinicadl/nn/networks/ae.py b/clinicadl/nn/networks/ae.py new file mode 100644 index 000000000..1a8ed283f --- /dev/null +++ b/clinicadl/nn/networks/ae.py @@ -0,0 +1,147 @@ +import numpy as np +from torch import nn + +from clinicadl.nn.blocks import Decoder3D, Encoder3D +from clinicadl.nn.layers import ( + CropMaxUnpool2d, + CropMaxUnpool3d, + PadMaxPool2d, + PadMaxPool3d, + Unflatten3D, +) +from clinicadl.nn.networks.cnn import Conv4_FC3, Conv5_FC3 +from clinicadl.nn.networks.factory import autoencoder_from_cnn +from clinicadl.nn.utils import compute_output_size +from clinicadl.utils.enum import BaseEnum + + +class AE2d(str, BaseEnum): + """AutoEncoders compatible with 2D inputs.""" + + AE_CONV5_FC3 = "AE_Conv5_FC3" + AE_CONV4_FC3 = "AE_Conv4_FC3" + + +class AE3d(str, BaseEnum): + """AutoEncoders compatible with 3D inputs.""" + + AE_CONV5_FC3 = "AE_Conv5_FC3" + AE_CONV4_FC3 = "AE_Conv4_FC3" + CAE_HALF = "CAE_half" + + +class ImplementedAE(str, BaseEnum): + """Implemented AutoEncoders in ClinicaDL.""" + + AE_CONV5_FC3 = "AE_Conv5_FC3" + AE_CONV4_FC3 = "AE_Conv4_FC3" + CAE_HALF = "CAE_half" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented AutoEncoders are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +# Networks # +class AE(nn.Module): + """Base class for AutoEncoders.""" + + def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def encode(self, x): + indices_list = [] + pad_list = [] + for layer in self.encoder: + if ( + (isinstance(layer, PadMaxPool3d) or isinstance(layer, PadMaxPool2d)) + and layer.return_indices + and layer.return_pad + ): + x, indices, pad = layer(x) + indices_list.append(indices) + pad_list.append(pad) + elif ( + isinstance(layer, nn.MaxPool3d) or isinstance(layer, nn.MaxPool2d) + ) and layer.return_indices: + x, indices = layer(x) + indices_list.append(indices) + else: + x = layer(x) + return x, indices_list, pad_list + + def decode(self, x, indices_list=None, pad_list=None): + for layer in self.decoder: + if isinstance(layer, CropMaxUnpool3d) or isinstance(layer, CropMaxUnpool2d): + x = layer(x, indices_list.pop(), pad_list.pop()) + elif isinstance(layer, nn.MaxUnpool3d) or isinstance(layer, nn.MaxUnpool2d): + x = layer(x, indices_list.pop()) + else: + x = layer(x) + return x + + def forward( + self, x + ): # TODO : simplify and remove indices_list and pad_list (it is too complicated, there are lot of cases that can raise an issue) + encoded, indices_list, pad_list = self.encode(x) + return self.decode(encoded, indices_list, pad_list) + + +class AE_Conv5_FC3(AE): + """ + Autoencoder derived from the convolutional part of CNN Conv5_FC3. + """ + + def __init__(self, input_size, dropout): + cnn_model = Conv5_FC3( + input_size=input_size, output_size=1, dropout=dropout + ) # outputsize is not useful as we only take the convolutional part + encoder, decoder = autoencoder_from_cnn(cnn_model) + super().__init__(encoder, decoder) + + +class AE_Conv4_FC3(AE): + """ + Autoencoder derived from the convolutional part of CNN Conv4_FC3. + """ + + def __init__(self, input_size, dropout): + cnn_model = Conv4_FC3( + input_size=input_size, output_size=1, dropout=dropout + ) # outputsize is not useful as we only take the convolutional part + encoder, decoder = autoencoder_from_cnn(cnn_model) + super().__init__(encoder, decoder) + + +class CAE_half(AE): + """ + 3D Autoencoder derived from CVAE. + """ + + def __init__( + self, input_size, latent_space_size + ): # TODO: doesn't work for even inputs + encoder = nn.Sequential( + Encoder3D(1, 32, kernel_size=3), + Encoder3D(32, 64, kernel_size=3), + Encoder3D(64, 128, kernel_size=3), + ) + conv_output_shape = compute_output_size(input_size, encoder) + flattened_size = np.prod(conv_output_shape) + encoder.append(nn.Flatten()) + encoder.append(nn.Linear(flattened_size, latent_space_size)) + decoder = nn.Sequential( + nn.Linear(latent_space_size, flattened_size * 2), + Unflatten3D( + 256, conv_output_shape[1], conv_output_shape[2], conv_output_shape[3] + ), + Decoder3D(256, 128, kernel_size=3), + Decoder3D(128, 64, kernel_size=3), + Decoder3D(64, 1, kernel_size=3), + ) + super().__init__(encoder, decoder) diff --git a/clinicadl/nn/networks/cnn.py b/clinicadl/nn/networks/cnn.py new file mode 100644 index 000000000..eb2104b1e --- /dev/null +++ b/clinicadl/nn/networks/cnn.py @@ -0,0 +1,288 @@ +import numpy as np +import torch +import torch.utils.model_zoo as model_zoo +from torch import nn +from torchvision.models.resnet import BasicBlock + +from clinicadl.nn.layers.factory import ( + get_conv_layer, + get_norm_layer, + get_pool_layer, +) +from clinicadl.utils.enum import BaseEnum + +from .factory import ResNetDesigner, ResNetDesigner3D, SECNNDesigner3D +from .factory.resnet import model_urls + + +class CNN2d(str, BaseEnum): + """CNNs compatible with 2D inputs.""" + + CONV5_FC3 = "Conv5_FC3" + CONV4_FC3 = "Conv4_FC3" + STRIDE_CONV5_FC3 = "Stride_Conv5_FC3" + RESNET = "resnet18" + + +class CNN3d(str, BaseEnum): + """CNNs compatible with 3D inputs.""" + + CONV5_FC3 = "Conv5_FC3" + CONV4_FC3 = "Conv4_FC3" + STRIDE_CONV5_FC3 = "Stride_Conv5_FC3" + RESNET3D = "ResNet3D" + SECNN = "SqueezeExcitationCNN" + + +class ImplementedCNN(str, BaseEnum): + """Implemented CNNs in ClinicaDL.""" + + CONV5_FC3 = "Conv5_FC3" + CONV4_FC3 = "Conv4_FC3" + STRIDE_CONV5_FC3 = "Stride_Conv5_FC3" + RESNET = "resnet18" + RESNET3D = "ResNet3D" + SECNN = "SqueezeExcitationCNN" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented CNNs are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +# Networks # +class CNN(nn.Module): + """Base class for CNN.""" + + def __init__(self, convolution_layers: nn.Module, fc_layers: nn.Module) -> None: + super().__init__() + self.convolutions = convolution_layers + self.fc = fc_layers + + def forward(self, x): + inter = self.convolutions(x) + print(self.convolutions) + print(inter.shape) + return self.fc(inter) + + +class Conv5_FC3(CNN): + """A Convolutional Neural Network with 5 convolution and 3 fully-connected layers.""" + + def __init__(self, input_size, output_size, dropout): + dim = len(input_size) - 1 + in_channels = input_size[0] + + conv = get_conv_layer(dim) + pool = get_pool_layer("PadMaxPool", dim=dim) + norm = get_norm_layer("BatchNorm", dim=dim) + + convolutions = nn.Sequential( + conv(in_channels, 8, 3, padding=1), + norm(8), + nn.ReLU(), + pool(2, 2), + conv(8, 16, 3, padding=1), + norm(16), + nn.ReLU(), + pool(2, 2), + conv(16, 32, 3, padding=1), + norm(32), + nn.ReLU(), + pool(2, 2), + conv(32, 64, 3, padding=1), + norm(64), + nn.ReLU(), + pool(2, 2), + conv(64, 128, 3, padding=1), + norm(128), + nn.ReLU(), + pool(2, 2), + ) + + input_tensor = torch.zeros(input_size).unsqueeze(0) + output_shape = convolutions(input_tensor).shape + + fc = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(np.prod(list(output_shape)).item(), 1300), + nn.ReLU(), + nn.Linear(1300, 50), + nn.ReLU(), + nn.Linear(50, output_size), + ) + super().__init__(convolutions, fc) + + +class Conv4_FC3(CNN): + """A Convolutional Neural Network with 4 convolution and 3 fully-connected layers.""" + + def __init__(self, input_size, output_size, dropout): + dim = len(input_size) - 1 + in_channels = input_size[0] + + conv = get_conv_layer(dim) + pool = get_pool_layer("PadMaxPool", dim=dim) + norm = get_norm_layer("BatchNorm", dim=dim) + + convolutions = nn.Sequential( + conv(in_channels, 8, 3, padding=1), + norm(8), + nn.ReLU(), + pool(2, 2), + conv(8, 16, 3, padding=1), + norm(16), + nn.ReLU(), + pool(2, 2), + conv(16, 32, 3, padding=1), + norm(32), + nn.ReLU(), + pool(2, 2), + conv(32, 64, 3, padding=1), + norm(64), + nn.ReLU(), + pool(2, 2), + conv(64, 128, 3, padding=1), + norm(128), + nn.ReLU(), + pool(2, 2), + ) + + input_tensor = torch.zeros(input_size).unsqueeze(0) + output_shape = convolutions(input_tensor).shape + + fc = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(np.prod(list(output_shape)).item(), 50), + nn.ReLU(), + nn.Linear(50, 40), + nn.ReLU(), + nn.Linear(40, output_size), + ) + super().__init__(convolutions, fc) + + +class Stride_Conv5_FC3(CNN): + """A Convolutional Neural Network with 5 convolution and 3 fully-connected layers and a stride of 2 for each convolutional layer.""" + + def __init__(self, input_size, output_size, dropout): + dim = len(input_size) - 1 + in_channels = input_size[0] + + conv = get_conv_layer(dim) + norm = get_norm_layer("BatchNorm", dim=dim) + + convolutions = nn.Sequential( + conv(in_channels, 8, 3, padding=1, stride=2), + norm(8), + nn.ReLU(), + conv(8, 16, 3, padding=1, stride=2), + norm(16), + nn.ReLU(), + conv(16, 32, 3, padding=1, stride=2), + norm(32), + nn.ReLU(), + conv(32, 64, 3, padding=1, stride=2), + norm(64), + nn.ReLU(), + conv(64, 128, 3, padding=1, stride=2), + norm(128), + nn.ReLU(), + ) + + input_tensor = torch.zeros(input_size).unsqueeze(0) + output_shape = convolutions(input_tensor).shape + + fc = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(np.prod(list(output_shape)).item(), 1300), + nn.ReLU(), + nn.Linear(1300, 50), + nn.ReLU(), + nn.Linear(50, output_size), + ) + super().__init__(convolutions, fc) + + +class resnet18(CNN): + """ + ResNet-18 is a neural network that is 18 layers deep based on residual block. + It uses skip connections or shortcuts to jump over some layers. + It is an image classification pre-trained model. + The model input has 3 channels in RGB order. + + Reference: Kaiming He et al., Deep Residual Learning for Image Recognition. + https://arxiv.org/abs/1512.03385?context=cs + """ + + def __init__(self, input_size, output_size, dropout): + model = ResNetDesigner(input_size, BasicBlock, [2, 2, 2, 2]) + model.load_state_dict(model_zoo.load_url(model_urls["resnet18"])) + + convolutions = nn.Sequential( + model.conv1, + model.bn1, + model.relu, + model.maxpool, + model.layer1, + model.layer2, + model.layer3, + model.layer4, + model.avgpool, + ) + + # add a fc layer on top of the transfer_learning model and a softmax classifier + fc = nn.Sequential(nn.Flatten(), model.fc) + fc.add_module("drop_out", nn.Dropout(p=dropout)) + fc.add_module("fc_out", nn.Linear(1000, output_size)) + + super().__init__(convolutions, fc) + + +class ResNet3D(CNN): + """ + ResNet3D is a 3D neural network composed of 5 residual blocks. Each residual block + is compose of 3D convolutions followed by a batch normalization and an activation function. + It uses skip connections or shortcuts to jump over some layers. It's a 3D version of the + original implementation of Kaiming He et al. + + Reference: Kaiming He et al., Deep Residual Learning for Image Recognition. + https://arxiv.org/abs/1512.03385?context=cs + """ + + def __init__(self, input_size, output_size, dropout): + model = ResNetDesigner3D(input_size, output_size, dropout) + convolutions = nn.Sequential( + model.layer0, model.layer1, model.layer2, model.layer3, model.layer4 + ) + fc_layers = model.fc + super().__init__(convolutions, fc_layers) + + +class SqueezeExcitationCNN(CNN): + """ + SE-CNN is a combination of a ResNet-101 with Squeeze and Excitation blocks which was successfully + tested on brain tumour classification by Ghosal et al. 2019. SE blocks are composed of a squeeze + and an excitation step. The squeeze operation is obtained through an average pooling layer and + provides a global understanding of each channel. + + The excitation part consists of a two-layer feed-forward network that outputs a vector of n values + corresponding to the weights of each channel of the feature maps. + + Reference: Ghosal et al. Brain Tumor Classification Using ResNet-101 Based Squeeze and Excitation Deep Neural Network + https://ieeexplore.ieee.org/document/8882973 + + """ + + def __init__(self, input_size, output_size, dropout): + model = SECNNDesigner3D(input_size, output_size, dropout) + convolutions = nn.Sequential( + model.layer0, model.layer1, model.layer2, model.layer3, model.layer4 + ) + fc_layers = model.fc + super().__init__(convolutions, fc_layers) diff --git a/clinicadl/nn/networks/factory/__init__.py b/clinicadl/nn/networks/factory/__init__.py new file mode 100644 index 000000000..85e6303c0 --- /dev/null +++ b/clinicadl/nn/networks/factory/__init__.py @@ -0,0 +1,3 @@ +from .ae import autoencoder_from_cnn +from .resnet import ResNetDesigner, ResNetDesigner3D +from .secnn import SECNNDesigner3D diff --git a/clinicadl/nn/networks/factory/ae.py b/clinicadl/nn/networks/factory/ae.py new file mode 100644 index 000000000..fccb14484 --- /dev/null +++ b/clinicadl/nn/networks/factory/ae.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING, List, Tuple + +from torch import nn + +from clinicadl.nn.layers import ( + CropMaxUnpool2d, + CropMaxUnpool3d, + PadMaxPool2d, + PadMaxPool3d, +) + +if TYPE_CHECKING: + from clinicadl.nn.networks.cnn import CNN + + +def autoencoder_from_cnn(model: CNN) -> Tuple[nn.Module, nn.Module]: + """ + Constructs an autoencoder from a given CNN. + + The encoder part corresponds to the convolutional part of the CNN. + The decoder part is the symmetrical network of the encoder. + + Parameters + ---------- + model : CNN + The input CNN model + + Returns + ------- + Tuple[nn.Module, nn.Module] + The encoder and the decoder. + """ + + encoder = deepcopy(model.convolutions) + decoder = _construct_inv_cnn(encoder) + + for i, layer in enumerate(encoder): + if isinstance(layer, PadMaxPool3d) or isinstance(layer, PadMaxPool2d): + encoder[i].set_new_return() + elif isinstance(layer, nn.MaxPool3d) or isinstance(layer, nn.MaxPool2d): + encoder[i].return_indices = True + + return encoder, decoder + + +def _construct_inv_cnn(model: nn.Module) -> nn.Module: + """ + Implements a decoder from an CNN encoder. + + The decoder part is the symmetrical list of the encoder + in which some layers are replaced by their transpose counterpart. + ConvTranspose and ReLU layers are also inverted. + + Parameters + ---------- + model : nn.Module + The input CNN encoder. + + Returns + ------- + nn.Module + The symmetrical CNN decoder. + """ + inv_layers = [] + for layer in model: + if isinstance(layer, nn.Conv3d): + inv_layers.append( + nn.ConvTranspose3d( + layer.out_channels, + layer.in_channels, + layer.kernel_size, + stride=layer.stride, + padding=layer.padding, + ) + ) + elif isinstance(layer, nn.Conv2d): + inv_layers.append( + nn.ConvTranspose2d( + layer.out_channels, + layer.in_channels, + layer.kernel_size, + stride=layer.stride, + padding=layer.padding, + ) + ) + elif isinstance(layer, PadMaxPool3d): + inv_layers.append(CropMaxUnpool3d(layer.kernel_size, stride=layer.stride)) + elif isinstance(layer, PadMaxPool2d): + inv_layers.append(CropMaxUnpool2d(layer.kernel_size, stride=layer.stride)) + elif isinstance(layer, nn.LeakyReLU): + inv_layers.append(nn.LeakyReLU(negative_slope=1 / layer.negative_slope)) + else: + inv_layers.append(deepcopy(layer)) + inv_layers = _invert_conv_and_relu(inv_layers) + inv_layers.reverse() + + return nn.Sequential(*inv_layers) + + +def _invert_conv_and_relu(inv_layers: List[nn.Module]) -> List[nn.Module]: + """ + Invert convolutional and ReLU layers (give empirical better results). + + Parameters + ---------- + inv_layers : List[nn.Module] + The list of layers. + + Returns + ------- + List[nn.Module] + The modified list of layers. + """ + idx_relu, idx_conv = -1, -1 + for idx, layer in enumerate(inv_layers): + if isinstance(layer, nn.ConvTranspose3d) or isinstance( + layer, nn.ConvTranspose2d + ): + idx_conv = idx + elif isinstance(layer, nn.ReLU) or isinstance(layer, nn.LeakyReLU): + idx_relu = idx + + if idx_conv != -1 and idx_relu != -1: + inv_layers[idx_relu], inv_layers[idx_conv] = ( + inv_layers[idx_conv], + inv_layers[idx_relu], + ) + idx_conv, idx_relu = -1, -1 + + # Check if number of features of batch normalization layers is still correct + for idx, layer in enumerate(inv_layers): + if isinstance(layer, nn.BatchNorm3d): + conv = inv_layers[idx + 1] + inv_layers[idx] = nn.BatchNorm3d(conv.out_channels) + elif isinstance(layer, nn.BatchNorm2d): + conv = inv_layers[idx + 1] + inv_layers[idx] = nn.BatchNorm2d(conv.out_channels) + + return inv_layers diff --git a/clinicadl/nn/networks/factory/resnet.py b/clinicadl/nn/networks/factory/resnet.py new file mode 100644 index 000000000..251199c92 --- /dev/null +++ b/clinicadl/nn/networks/factory/resnet.py @@ -0,0 +1,119 @@ +import math + +import torch +from torch import nn + +from clinicadl.nn.blocks import ResBlock + +model_urls = {"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth"} + + +class ResNetDesigner(nn.Module): + def __init__(self, input_size, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNetDesigner, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + # Compute avgpool size + input_tensor = torch.zeros(input_size).unsqueeze(0) + out = self.conv1(input_tensor) + out = self.relu(self.bn1(out)) + out = self.maxpool(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + + self.avgpool = nn.AvgPool2d((out.size(2), out.size(3)), stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + +class ResNetDesigner3D(nn.Module): + def __init__(self, input_size, output_size, dropout): + super(ResNetDesigner3D, self).__init__() + + assert ( + len(input_size) == 4 + ), "Input must be in 3D with the corresponding number of channels." + + self.layer0 = self._make_block(1, input_size[0]) + self.layer1 = self._make_block(2) + self.layer2 = self._make_block(3) + self.layer3 = self._make_block(4) + self.layer4 = self._make_block(5) + + input_tensor = torch.zeros(input_size).unsqueeze(0) + out = self.layer0(input_tensor) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + + d, h, w = self._maxpool_output_size(input_size[1::], nb_layers=5) + self.fc = nn.Sequential( + nn.Flatten(), + nn.Linear(128 * d * h * w, 256), # t1 image + nn.ELU(), + nn.Dropout(p=dropout), + nn.Linear(256, output_size), + ) + + for layer in self.fc: + out = layer(out) + + def _make_block(self, block_number, input_size=None): + return nn.Sequential( + ResBlock(block_number, input_size), nn.MaxPool3d(3, stride=2) + ) + + def _maxpool_output_size( + self, input_size, kernel_size=(3, 3, 3), stride=(2, 2, 2), nb_layers=1 + ): + import math + + d = math.floor((input_size[0] - kernel_size[0]) / stride[0] + 1) + h = math.floor((input_size[1] - kernel_size[1]) / stride[1] + 1) + w = math.floor((input_size[2] - kernel_size[2]) / stride[2] + 1) + + if nb_layers == 1: + return d, h, w + return self._maxpool_output_size( + (d, h, w), kernel_size=kernel_size, stride=stride, nb_layers=nb_layers - 1 + ) diff --git a/clinicadl/nn/networks/factory/secnn.py b/clinicadl/nn/networks/factory/secnn.py new file mode 100644 index 000000000..270f0a357 --- /dev/null +++ b/clinicadl/nn/networks/factory/secnn.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn + +from clinicadl.nn.blocks import ResBlock_SE + + +class SECNNDesigner3D(nn.Module): + def __init__(self, input_size, output_size, dropout): + super(SECNNDesigner3D, self).__init__() + + assert ( + len(input_size) == 4 + ), "input must be in 3d with the corresponding number of channels" + + self.layer0 = self._make_block(1, 8, 8, input_size[0]) + self.layer1 = self._make_block(2, 16) + self.layer2 = self._make_block(3, 32) + self.layer3 = self._make_block(4, 64) + self.layer4 = self._make_block(5, 128) + + input_tensor = torch.zeros(input_size).unsqueeze(0) + out = self.layer0(input_tensor) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + + d, h, w = self._maxpool_output_size(input_size[1::], nb_layers=5) + self.fc = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(128 * d * h * w, 256), # t1 image + nn.ReLU(), + nn.Linear(256, output_size), + ) + + for layer in self.fc: + out = layer(out) + + def _make_block( + self, block_number, num_channels, ration_channel=8, input_size=None + ): + return nn.Sequential( + ResBlock_SE(block_number, input_size, num_channels, ration_channel), + nn.MaxPool3d(3, stride=2), + ) + + def _maxpool_output_size( + self, input_size, kernel_size=(3, 3, 3), stride=(2, 2, 2), nb_layers=1 + ): + import math + + d = math.floor((input_size[0] - kernel_size[0]) / stride[0] + 1) + h = math.floor((input_size[1] - kernel_size[1]) / stride[1] + 1) + w = math.floor((input_size[2] - kernel_size[2]) / stride[2] + 1) + + if nb_layers == 1: + return d, h, w + return self._maxpool_output_size( + (d, h, w), kernel_size=kernel_size, stride=stride, nb_layers=nb_layers - 1 + ) diff --git a/clinicadl/nn/networks/random.py b/clinicadl/nn/networks/random.py new file mode 100644 index 000000000..50b18dd60 --- /dev/null +++ b/clinicadl/nn/networks/random.py @@ -0,0 +1,222 @@ +import numpy as np +import torch.nn as nn + +from clinicadl.nn.layers import PadMaxPool2d, PadMaxPool3d +from clinicadl.nn.networks.cnn import CNN +from clinicadl.utils.exceptions import ClinicaDLNetworksError + + +class RandomArchitecture(CNN): # TODO : unabled to test it + def __init__( + self, + convolutions_dict, + n_fcblocks, + input_size, + dropout=0.5, + network_normalization="BatchNorm", + output_size=2, + ): + """ + Construct the Architecture randomly chosen for Random Search. + + Args: + convolutions_dict: (dict) description of the convolutional blocks. + n_fcblocks: (int) number of FC blocks in the network. + input_size: (list) gives the structure of the input of the network. + dropout: (float) rate of the dropout. + network_normalization: (str) type of normalization layer in the network. + output_size: (int) Number of output neurones of the network. + gpu: (bool) If True the network weights are stored on a CPU, else GPU. + """ + self.dimension = len(input_size) - 1 + self.first_in_channels = input_size[0] + self.layers_dict = self.return_layers_dict() + self.network_normalization = network_normalization + + convolutions = nn.Sequential() + for key, item in convolutions_dict.items(): + convolutional_block = self._define_convolutional_block(item) + convolutions.add_module(key, convolutional_block) + + classifier = nn.Sequential(nn.Flatten(), nn.Dropout(p=dropout)) + + fc, _ = self._fc_dict_design( + n_fcblocks, convolutions_dict, input_size, output_size + ) + for key, item in fc.items(): + n_fc = int(key[2::]) + if n_fc == len(fc) - 1: + fc_block = self._define_fc_layer(item, last_block=True) + else: + fc_block = self._define_fc_layer(item, last_block=False) + classifier.add_module(key, fc_block) + + super().__init__(convolution_layers=convolutions, fc_layers=classifier) + + def _define_convolutional_block(self, conv_dict): + """ + Design a convolutional block from the dictionary conv_dict. + + Args: + conv_dict: (dict) A dictionary with the specifications to build a convolutional block + - n_conv (int) number of convolutional layers in the block + - in_channels (int) number of input channels + - out_channels (int) number of output channels (2 * in_channels or threshold = 512) + - d_reduction (String) "MaxPooling" or "stride" + Returns: + (nn.Module) a list of modules in a nn.Sequential list + """ + in_channels = ( + conv_dict["in_channels"] + if conv_dict["in_channels"] is not None + else self.first_in_channels + ) + out_channels = conv_dict["out_channels"] + + conv_block = [] + for i in range(conv_dict["n_conv"] - 1): + conv_block.append( + self.layers_dict["Conv"]( + in_channels, in_channels, 3, stride=1, padding=1 + ) + ) + conv_block = self._append_normalization_layer(conv_block, in_channels) + conv_block.append(nn.LeakyReLU()) + if conv_dict["d_reduction"] == "MaxPooling": + conv_block.append( + self.layers_dict["Conv"]( + in_channels, out_channels, 3, stride=1, padding=1 + ) + ) + conv_block = self._append_normalization_layer(conv_block, out_channels) + conv_block.append(nn.LeakyReLU()) + conv_block.append(self.layers_dict["Pool"](2, 2)) + elif conv_dict["d_reduction"] == "stride": + conv_block.append( + self.layers_dict["Conv"]( + in_channels, out_channels, 3, stride=2, padding=1 + ) + ) + conv_block = self._append_normalization_layer(conv_block, out_channels) + conv_block.append(nn.LeakyReLU()) + else: + raise ClinicaDLNetworksError( + f"Dimension reduction {conv_dict['d_reduction']} is not supported. Please only include" + "'MaxPooling' or 'stride' in your sampling options." + ) + + return nn.Sequential(*conv_block) + + def _append_normalization_layer(self, conv_block, num_features): + """ + Appends or not a normalization layer to a convolutional block depending on network attributes. + + Args: + conv_block: (list) list of the modules of the convolutional block + num_features: (int) number of features to normalize + Returns: + (list) the updated convolutional block + """ + if self.network_normalization in ["BatchNorm", "InstanceNorm", "GroupNorm"]: + conv_block.append( + self.layers_dict[self.network_normalization](num_features) + ) + elif self.network_normalization is not None: + raise ClinicaDLNetworksError( + f"The network normalization {self.network_normalization} value must be in ['BatchNorm', 'InstanceNorm', 'GroupNorm', None]" + ) + return conv_block + + def return_layers_dict(self): + if self.dimension == 3: + layers = { + "Conv": nn.Conv3d, + "Pool": PadMaxPool3d, + "InstanceNorm": nn.InstanceNorm3d, + "BatchNorm": nn.BatchNorm3d, + "GroupNorm": nn.GroupNorm, + } + elif self.dimension == 2: + layers = { + "Conv": nn.Conv2d, + "Pool": PadMaxPool2d, + "InstanceNorm": nn.InstanceNorm2d, + "BatchNorm": nn.BatchNorm2d, + "GroupNorm": nn.GroupNorm, + } + else: + raise ValueError( + "Cannot construct random network in dimension {self.dimension}." + ) + return layers + + @staticmethod + def _define_fc_layer(fc_dict, last_block=False): + """ + Implement the FC block from the dictionary fc_dict. + + Args: + fc_dict: (dict) A dictionary with the specifications to build a FC block + - in_features (int) number of input neurones + - out_features (int) number of output neurones + last_block: (bool) indicates if the current FC layer is the last one of the network. + Returns: + (nn.Module) a list of modules in a nn.Sequential list + """ + in_features = fc_dict["in_features"] + out_features = fc_dict["out_features"] + + if last_block: + fc_block = [nn.Linear(in_features, out_features)] + else: + fc_block = [nn.Linear(in_features, out_features), nn.LeakyReLU()] + + return nn.Sequential(*fc_block) + + @staticmethod + def recursive_init(layer): + if isinstance(layer, nn.Sequential): + for sub_layer in layer: + RandomArchitecture.recursive_init(sub_layer) + else: + try: + layer.reset_parameters() + except AttributeError: + pass + + @staticmethod + def _fc_dict_design(n_fcblocks, convolutions, initial_shape, n_classes=2): + """ + Sample parameters for a random architecture (FC part). + + Args: + n_fcblocks: (int) number of fully connected blocks in the architecture. + convolutions: (dict) parameters of the convolutional part. + initial_shape: (array_like) shape of the initial input. + n_classes: (int) number of classes in the classification problem. + Returns: + (dict) parameters of the architecture + (list) the shape of the flattened layer + """ + n_conv = len(convolutions) + last_conv = convolutions[f"conv{(len(convolutions) - 1)}"] + out_channels = last_conv["out_channels"] + flattened_shape = np.ceil(np.array(initial_shape) / 2**n_conv) + flattened_shape[0] = out_channels + in_features = np.product(flattened_shape) + + # Sample number of FC layers + ratio = (in_features / n_classes) ** (1 / n_fcblocks) + + # Designing the parameters of each FC block + fc = dict() + for i in range(n_fcblocks): + fc_dict = dict() + out_features = in_features / ratio + fc_dict["in_features"] = int(np.round(in_features)) + fc_dict["out_features"] = int(np.round(out_features)) + + in_features = out_features + fc["FC" + str(i)] = fc_dict + + return fc, flattened_shape diff --git a/clinicadl/nn/networks/ssda.py b/clinicadl/nn/networks/ssda.py new file mode 100644 index 000000000..a87cb33b5 --- /dev/null +++ b/clinicadl/nn/networks/ssda.py @@ -0,0 +1,111 @@ +import numpy as np +import torch +import torch.nn as nn + +from clinicadl.nn.layers import ( + GradientReversal, + get_conv_layer, + get_norm_layer, + get_pool_layer, +) + + +class CNN_SSDA(nn.Module): + """Base class for SSDA CNN.""" + + def __init__( + self, + convolutions, + fc_class_source, + fc_class_target, + fc_domain, + alpha=1.0, + ): + super().__init__() + self.convolutions = convolutions + self.fc_class_source = fc_class_source + self.fc_class_target = fc_class_target + self.fc_domain = fc_domain + self.grad_reverse = GradientReversal(alpha=alpha) + + def forward(self, x): + x = self.convolutions(x) + x_class_source = self.fc_class_source(x) + x_class_target = self.fc_class_target(x) + x_reverse = self.grad_reverse(x) + x_domain = self.fc_domain(x_reverse) + return x_class_source, x_class_target, x_domain + + +class Conv5_FC3_SSDA(CNN_SSDA): + """ + Reduce the 2D or 3D input image to an array of size output_size. + """ + + def __init__(self, input_size, output_size=2, dropout=0.5, alpha=1.0): + dim = len(input_size) - 1 + conv = get_conv_layer(dim) + pool = get_pool_layer("PadMaxPool", dim=dim) + norm = get_norm_layer("BatchNorm", dim=dim) + + convolutions = nn.Sequential( + conv(input_size[0], 8, 3, padding=1), + norm(8), + nn.ReLU(), + pool(2, 2), + conv(8, 16, 3, padding=1), + norm(16), + nn.ReLU(), + pool(2, 2), + conv(16, 32, 3, padding=1), + norm(32), + nn.ReLU(), + pool(2, 2), + conv(32, 64, 3, padding=1), + norm(64), + nn.ReLU(), + pool(2, 2), + conv(64, 128, 3, padding=1), + norm(128), + nn.ReLU(), + pool(2, 2), + ) + + # Compute the size of the first FC layer + input_tensor = torch.zeros(input_size).unsqueeze(0) + output_convolutions = convolutions(input_tensor) + + fc_class_source = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(np.prod(list(output_convolutions.shape)).item(), 1300), + nn.ReLU(), + nn.Linear(1300, 50), + nn.ReLU(), + nn.Linear(50, output_size), + ) + fc_class_target = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(np.prod(list(output_convolutions.shape)).item(), 1300), + nn.ReLU(), + nn.Linear(1300, 50), + nn.ReLU(), + nn.Linear(50, output_size), + ) + fc_domain = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(np.prod(list(output_convolutions.shape)).item(), 1300), + nn.ReLU(), + nn.Linear(1300, 50), + nn.ReLU(), + nn.Linear(50, output_size), + ) + super().__init__( + convolutions, + fc_class_source, + fc_class_target, + fc_domain, + alpha, + ) diff --git a/clinicadl/nn/networks/unet.py b/clinicadl/nn/networks/unet.py new file mode 100644 index 000000000..45850de29 --- /dev/null +++ b/clinicadl/nn/networks/unet.py @@ -0,0 +1,39 @@ +from torch import nn + +from clinicadl.nn.blocks import UNetDown, UNetFinalLayer, UNetUp + + +class UNet(nn.Module): + """ + Generator Unet. + """ + + def __init__(self): + super().__init__() + + self.down1 = UNetDown(1, 64) + self.down2 = UNetDown(64, 128) + self.down3 = UNetDown(128, 256) + self.down4 = UNetDown(256, 512) + self.down5 = UNetDown(512, 512) + + self.up1 = UNetUp(512, 512) + self.up2 = UNetUp(1024, 256) + self.up3 = UNetUp(512, 128) + self.up4 = UNetUp(256, 64) + + self.final = UNetFinalLayer(128, 1) + + def forward(self, x): + d1 = self.down1(x) + d2 = self.down2(d1) + d3 = self.down3(d2) + d4 = self.down4(d3) + d5 = self.down5(d4) + + u1 = self.up1(d5) + u2 = self.up2(u1, d4) + u3 = self.up3(u2, d3) + u4 = self.up4(u3, d2) + + return self.final(u4, d1) diff --git a/clinicadl/nn/networks/vae.py b/clinicadl/nn/networks/vae.py new file mode 100644 index 000000000..9e9b3e72f --- /dev/null +++ b/clinicadl/nn/networks/vae.py @@ -0,0 +1,566 @@ +import torch +import torch.nn as nn + +from clinicadl.nn.blocks import ( + Decoder3D, + Encoder3D, + VAE_Decoder2D, + VAE_Encoder2D, +) +from clinicadl.nn.layers import Unflatten3D +from clinicadl.nn.utils import multiply_list +from clinicadl.utils.enum import BaseEnum + + +class VAE2d(str, BaseEnum): + """VAEs compatible with 2D inputs.""" + + VANILLA_DENSE_VAE = "VanillaDenseVAE" + VANILLA_SPATIAL_VAE = "VanillaSpatialVAE" + + +class VAE3d(str, BaseEnum): + """VAEs compatible with 3D inputs.""" + + VANILLA_DENSE_VAE3D = "VanillaSpatialVAE3D" + VANILLA_SPATIAL_VAE3D = "VanillaDenseVAE3D" + CVAE_3D_FINAL_CONV = "CVAE_3D_final_conv" + CVAE_3D = "CVAE_3D" + CVAE_3D_HALF = "CVAE_3D_half" + + +class ImplementedVAE(str, BaseEnum): + """Implemented VAEs in ClinicaDL.""" + + VANILLA_DENSE_VAE = "VanillaDenseVAE" + VANILLA_SPATIAL_VAE = "VanillaSpatialVAE" + VANILLA_DENSE_VAE3D = "VanillaDenseVAE3D" + VANILLA_SPATIAL_VAE3D = "VanillaSpatialVAE3D" + CVAE_3D_FINAL_CONV = "CVAE_3D_final_conv" + CVAE_3D = "CVAE_3D" + CVAE_3D_HALF = "CVAE_3D_half" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented VAEs are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +class VAE(nn.Module): + def __init__(self, encoder, decoder, mu_layers, log_var_layers): + super().__init__() + self.encoder = encoder + self.mu_layers = mu_layers + self.log_var_layers = log_var_layers + self.decoder = decoder + + def encode(self, image): + feature = self.encoder(image) + mu = self.mu_layers(feature) + log_var = self.log_var_layers(feature) + return mu, log_var + + def decode(self, encoded): + reconstructed = self.decoder(encoded) + return reconstructed + + @staticmethod + def _sample(mu, log_var): + std = torch.exp(log_var / 2) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, image): + mu, log_var = self.encode(image) + if self.training: + encoded = self._sample(mu, log_var) + else: + encoded = mu + reconstructed = self.decode(encoded) + return mu, log_var, reconstructed + + +class VanillaDenseVAE(VAE): + """ + This network is a 2D convolutional variational autoencoder with a dense latent space. + + reference: Diederik P Kingma et al., Auto-Encoding Variational Bayes. + https://arxiv.org/abs/1312.6114 + """ + + def __init__(self, input_size, latent_space_size, feature_size): + n_conv = 4 + io_layer_channel = 32 + + encoder = VAE_Encoder2D( + input_shape=input_size, + feature_size=feature_size, + latent_dim=1, + n_conv=n_conv, + first_layer_channels=io_layer_channel, + ) + mu_layers = nn.Linear(feature_size, latent_space_size) + log_var_layers = nn.Linear(feature_size, latent_space_size) + decoder = VAE_Decoder2D( + input_shape=input_size, + latent_size=latent_space_size, + feature_size=feature_size, + latent_dim=1, + n_conv=n_conv, + last_layer_channels=io_layer_channel, + padding=encoder.decoder_padding, + ) + super().__init__(encoder, decoder, mu_layers, log_var_layers) + + +class VanillaDenseVAE3D(VAE): + """ + This network is a 3D convolutional variational autoencoder with a dense latent space. + + reference: Diederik P Kingma et al., Auto-Encoding Variational Bayes. + https://arxiv.org/abs/1312.6114 + """ + + def __init__( + self, + size_reduction_factor, + latent_space_size=256, + feature_size=1024, + n_conv=4, + io_layer_channels=8, + ): + first_layer_channels = io_layer_channels + last_layer_channels = io_layer_channels + # automatically compute padding + decoder_output_padding = [] + + if ( + size_reduction_factor == 2 + ): # TODO : specify that it only works with certain images + self.input_size = [1, 80, 96, 80] + elif size_reduction_factor == 3: + self.input_size = [1, 56, 64, 56] + elif size_reduction_factor == 4: + self.input_size = [1, 40, 48, 40] + elif size_reduction_factor == 5: + self.input_size = [1, 32, 40, 32] + + input_c = self.input_size[0] + input_d = self.input_size[1] + input_h = self.input_size[2] + input_w = self.input_size[3] + d, h, w = input_d, input_h, input_w + + # ENCODER + encoder_layers = [] + # Input Layer + encoder_layers.append(Encoder3D(input_c, first_layer_channels)) + decoder_output_padding.append([d % 2, h % 2, w % 2]) + d, h, w = d // 2, h // 2, w // 2 + # Conv Layers + for i in range(n_conv - 1): + encoder_layers.append( + Encoder3D( + first_layer_channels * 2**i, first_layer_channels * 2 ** (i + 1) + ) + ) + # Construct output paddings + decoder_output_padding.append([d % 2, h % 2, w % 2]) + d, h, w = d // 2, h // 2, w // 2 + # Compute size of the feature space + n_pix = ( + first_layer_channels + * 2 ** (n_conv - 1) + * (input_d // (2**n_conv)) + * (input_h // (2**n_conv)) + * (input_w // (2**n_conv)) + ) + # Flatten + encoder_layers.append(nn.Flatten()) + # Intermediate feature space + if feature_size == 0: + feature_space = n_pix + else: + feature_space = feature_size + encoder_layers.append( + nn.Sequential(nn.Linear(n_pix, feature_space), nn.ReLU()) + ) + encoder = nn.Sequential(*encoder_layers) + + # LATENT SPACE + mu_layers = nn.Linear(feature_space, latent_space_size) + log_var_layers = nn.Linear(feature_space, latent_space_size) + + # DECODER + decoder_layers = [] + # Intermediate feature space + if feature_size == 0: + decoder_layers.append( + nn.Sequential( + nn.Linear(latent_space_size, n_pix), + nn.ReLU(), + ) + ) + else: + decoder_layers.append( + nn.Sequential( + nn.Linear(latent_space_size, feature_size), + nn.ReLU(), + nn.Linear(feature_size, n_pix), + nn.ReLU(), + ) + ) + # Unflatten + decoder_layers.append( + Unflatten3D( + last_layer_channels * 2 ** (n_conv - 1), + input_d // (2**n_conv), + input_h // (2**n_conv), + input_w // (2**n_conv), + ) + ) + # Decoder layers + for i in range(n_conv - 1, 0, -1): + decoder_layers.append( + Decoder3D( + last_layer_channels * 2 ** (i), + last_layer_channels * 2 ** (i - 1), + output_padding=decoder_output_padding[i], + ) + ) + # Output layer + decoder_layers.append( + nn.Sequential( + nn.ConvTranspose3d( + last_layer_channels, + input_c, + 4, + stride=2, + padding=1, + output_padding=decoder_output_padding[0], + bias=False, + ), + nn.Sigmoid(), + ) + ) + decoder = nn.Sequential(*decoder_layers) + + super().__init__(encoder, decoder, mu_layers, log_var_layers) + + +class VanillaSpatialVAE(VAE): + """ + This network is a 2D convolutional variational autoencoder with a spatial latent space. + + reference: Diederik P Kingma et al., Auto-Encoding Variational Bayes. + https://arxiv.org/abs/1312.6114 + """ + + def __init__( + self, + input_size, + ): + feature_channels = 64 + latent_channels = 1 + n_conv = 4 + io_layer_channel = 32 + + encoder = VAE_Encoder2D( + input_shape=input_size, + feature_size=feature_channels, + latent_dim=2, + n_conv=n_conv, + first_layer_channels=io_layer_channel, + ) + mu_layers = nn.Conv2d( + feature_channels, latent_channels, 3, stride=1, padding=1, bias=False + ) + log_var_layers = nn.Conv2d( + feature_channels, latent_channels, 3, stride=1, padding=1, bias=False + ) + decoder = VAE_Decoder2D( + input_shape=input_size, + latent_size=latent_channels, + feature_size=feature_channels, + latent_dim=2, + n_conv=n_conv, + last_layer_channels=io_layer_channel, + padding=encoder.decoder_padding, + ) + super().__init__(encoder, decoder, mu_layers, log_var_layers) + + +class VanillaSpatialVAE3D(VAE): + """ + This network is a 3D convolutional variational autoencoder with a spatial latent space. + + reference: Diederik P Kingma et al., Auto-Encoding Variational Bayes. + https://arxiv.org/abs/1312.6114 + """ + + def __init__(self, input_size): + n_conv = 4 + first_layer_channels = 32 + last_layer_channels = 32 + feature_channels = 512 + latent_channels = 1 + decoder_output_padding = [ + [1, 0, 0], + [0, 0, 0], + [0, 0, 1], + ] + input_c = input_size[0] + + encoder_layers = [] + encoder_layers.append(Encoder3D(input_c, first_layer_channels)) + for i in range(n_conv - 1): + encoder_layers.append( + Encoder3D( + first_layer_channels * 2**i, first_layer_channels * 2 ** (i + 1) + ) + ) + encoder_layers.append( + nn.Sequential( + nn.Conv3d( + first_layer_channels * 2 ** (n_conv - 1), + feature_channels, + 4, + stride=2, + padding=1, + bias=False, + ), + nn.ReLU(), + ) + ) + encoder = nn.Sequential(*encoder_layers) + mu_layers = nn.Conv3d( + feature_channels, latent_channels, 3, stride=1, padding=1, bias=False + ) + log_var_layers = nn.Conv3d( + feature_channels, latent_channels, 3, stride=1, padding=1, bias=False + ) + decoder_layers = [] + decoder_layers.append( + nn.Sequential( + nn.ConvTranspose3d( + latent_channels, + feature_channels, + 3, + stride=1, + padding=1, + bias=False, + ), + nn.ReLU(), + nn.ConvTranspose3d( + feature_channels, + last_layer_channels * 2 ** (n_conv - 1), + 4, + stride=2, + padding=1, + output_padding=[0, 1, 1], + bias=False, + ), + nn.ReLU(), + ) + ) + for i in range(n_conv - 1, 0, -1): + decoder_layers.append( + Decoder3D( + last_layer_channels * 2 ** (i), + last_layer_channels * 2 ** (i - 1), + output_padding=decoder_output_padding[i], + ) + ) + decoder_layers.append( + nn.Sequential( + nn.ConvTranspose3d( + last_layer_channels, + input_c, + 4, + stride=2, + padding=1, + output_padding=[1, 0, 1], + bias=False, + ), + nn.Sigmoid(), + ) + ) + decoder = nn.Sequential(*decoder_layers) + super().__init__(encoder, decoder, mu_layers, log_var_layers) + + +class CVAE_3D_final_conv(VAE): + """ + This is the convolutional autoencoder whose main objective is to project the MRI into a smaller space + with the sole criterion of correctly reconstructing the data. Nothing longitudinal here. + fc = final layer conv + """ + + def __init__(self, size_reduction_factor, latent_space_size): + n_conv = 3 + + if size_reduction_factor == 2: + self.input_size = [1, 80, 96, 80] + elif size_reduction_factor == 3: + self.input_size = [1, 56, 64, 56] + elif size_reduction_factor == 4: + self.input_size = [1, 40, 48, 40] + elif size_reduction_factor == 5: + self.input_size = [1, 32, 40, 32] + feature_size = int(multiply_list(self.input_size[1:], 2**n_conv) * 128) + + encoder = nn.Sequential( + nn.Conv3d(1, 32, 3, stride=2, padding=1), + nn.InstanceNorm3d(32), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv3d(32, 64, 3, stride=2, padding=1), + nn.InstanceNorm3d(64), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv3d(64, 128, 3, stride=2, padding=1), + nn.InstanceNorm3d(128), + nn.LeakyReLU(negative_slope=0.2), + nn.Flatten(start_dim=1), + ) + mu_layers = nn.Sequential( + nn.Linear(feature_size, latent_space_size), + nn.Tanh(), + ) + log_var_layers = nn.Linear(feature_size, latent_space_size) + decoder = nn.Sequential( + nn.Linear(latent_space_size, 2 * feature_size), + nn.LeakyReLU(), + nn.Unflatten( + dim=1, + unflattened_size=( + 256, + self.input_size[1] // 2**n_conv, + self.input_size[2] // 2**n_conv, + self.input_size[3] // 2**n_conv, + ), + ), + nn.ConvTranspose3d(256, 128, 3, stride=2, padding=1, output_padding=1), + nn.InstanceNorm3d(128), + nn.LeakyReLU(), + nn.ConvTranspose3d(128, 64, 3, stride=2, padding=1, output_padding=1), + nn.InstanceNorm3d(64), + nn.LeakyReLU(), + nn.ConvTranspose3d(64, 1, 3, stride=2, padding=1, output_padding=1), + nn.InstanceNorm3d(1), + nn.LeakyReLU(), + nn.Conv3d(1, 1, 3, stride=1, padding=1), + nn.Sigmoid(), + ) + super().__init__(encoder, decoder, mu_layers, log_var_layers) + + +class CVAE_3D(VAE): + """ + This is the convolutional autoencoder whose main objective is to project the MRI into a smaller space + with the sole criterion of correctly reconstructing the data. Nothing longitudinal here. + """ + + def __init__(self, latent_space_size): # TODO : only work with 1-channel input + encoder = nn.Sequential( + nn.Conv3d(1, 32, 3, stride=2, padding=1), + nn.BatchNorm3d(32), + nn.ReLU(), + nn.Conv3d(32, 64, 3, stride=2, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.Conv3d(64, 128, 3, stride=2, padding=1), + nn.BatchNorm3d(128), + nn.ReLU(), + nn.Flatten(start_dim=1), + ) + mu_layers = nn.Sequential( + nn.Linear(1683968, latent_space_size), + nn.Tanh(), + ) + log_var_layers = nn.Linear(1683968, latent_space_size) + decoder = nn.Sequential( + nn.Linear(latent_space_size, 3367936), + nn.ReLU(), + nn.Unflatten( + dim=1, + unflattened_size=( + 256, + 22, + 26, + 23, + ), + ), + nn.ConvTranspose3d( + 256, 128, 3, stride=2, padding=1, output_padding=[0, 1, 0] + ), + nn.BatchNorm3d(128), + nn.ReLU(), + nn.ConvTranspose3d( + 128, 64, 3, stride=2, padding=1, output_padding=[0, 1, 1] + ), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.ConvTranspose3d(64, 1, 3, stride=2, padding=1, output_padding=[0, 1, 0]), + nn.ReLU(), + ) + super().__init__(encoder, decoder, mu_layers, log_var_layers) + + +class CVAE_3D_half(VAE): + """ + This is the convolutional autoencoder whose main objective is to project the MRI into a smaller space + with the sole criterion of correctly reconstructing the data. Nothing longitudinal here. + """ + + def __init__(self, size_reduction_factor, latent_space_size): + n_conv = 3 + if size_reduction_factor == 2: + self.input_size = [1, 80, 96, 80] + elif size_reduction_factor == 3: + self.input_size = [1, 56, 64, 56] + elif size_reduction_factor == 4: + self.input_size = [1, 40, 48, 40] + elif size_reduction_factor == 5: + self.input_size = [1, 32, 40, 32] + feature_size = int(multiply_list(self.input_size[1:], 2**n_conv) * 128) + + encoder = nn.Sequential( + nn.Conv3d(1, 32, 3, stride=2, padding=1), + nn.InstanceNorm3d(32), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv3d(32, 64, 3, stride=2, padding=1), + nn.InstanceNorm3d(64), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv3d(64, 128, 3, stride=2, padding=1), + nn.InstanceNorm3d(128), + nn.LeakyReLU(negative_slope=0.2), + nn.Flatten(start_dim=1), + ) + mu_layers = nn.Sequential( + nn.Linear(feature_size, latent_space_size), + nn.Tanh(), + ) + log_var_layers = nn.Linear(feature_size, latent_space_size) + decoder = nn.Sequential( + nn.Linear(latent_space_size, 2 * feature_size), + nn.ReLU(), + nn.Unflatten( + dim=1, + unflattened_size=( + 256, + self.input_size[1] // 2**n_conv, + self.input_size[2] // 2**n_conv, + self.input_size[3] // 2**n_conv, + ), + ), + nn.ConvTranspose3d(256, 128, 3, stride=2, padding=1, output_padding=1), + nn.BatchNorm3d(128), + nn.ReLU(), + nn.ConvTranspose3d(128, 64, 3, stride=2, padding=1, output_padding=1), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.ConvTranspose3d(64, 1, 3, stride=2, padding=1, output_padding=1), + nn.Sigmoid(), + ) + super().__init__(encoder, decoder, mu_layers, log_var_layers) diff --git a/clinicadl/nn/utils.py b/clinicadl/nn/utils.py new file mode 100644 index 000000000..dc3afd71c --- /dev/null +++ b/clinicadl/nn/utils.py @@ -0,0 +1,74 @@ +from collections.abc import Iterable +from typing import Any, Callable, Dict, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.modules.module import _addindent + + +def torch_summarize(model, show_weights=True, show_parameters=True): + """Summarizes torch model by showing trainable parameters and weights.""" + + tmpstr = model.__class__.__name__ + " (\n" + for key, module in model._modules.items(): + # if it contains layers let call it recursively to get params and weights + if type(module) in [ + torch.nn.modules.container.Container, + torch.nn.modules.container.Sequential, + ]: + modstr = torch_summarize(module) + else: + modstr = module.__repr__() + modstr = _addindent(modstr, 2) + + params = sum([np.prod(p.size()) for p in module.parameters()]) + weights = tuple([tuple(p.size()) for p in module.parameters()]) + + tmpstr += " (" + key + "): " + modstr + if show_weights: + tmpstr += ", weights={}".format(weights) + if show_parameters: + tmpstr += ", parameters={}".format(params) + tmpstr += "\n" + + tmpstr = tmpstr + ")" + return tmpstr + + +def multiply_list(L, factor): + product = 1 + for x in L: + product = product * x / factor + return product + + +def compute_output_size( + input_size: Union[torch.Size, Tuple], layer: nn.Module +) -> Tuple: + """ + Computes the output size of a layer. + + Parameters + ---------- + input_size : Union[torch.Size, Tuple] + The unbatched input size (i.e. C, H, W(, D)) + layer : nn.Module + The layer. + + Returns + ------- + Tuple + The unbatched size of the output. + """ + input_ = torch.randn(input_size).unsqueeze(0) + if isinstance(layer, nn.MaxUnpool3d) or isinstance(layer, nn.MaxUnpool2d): + indices = torch.zeros_like(input_, dtype=int) + print(indices) + output = layer(input_, indices) + else: + output = layer(input_) + if isinstance(layer, nn.MaxPool3d) or isinstance(layer, nn.MaxPool2d): + if layer.return_indices: + output = output[0] + return tuple(output.shape[1:]) diff --git a/clinicadl/utils/enum.py b/clinicadl/utils/enum.py index 47617ef6b..3e9031534 100644 --- a/clinicadl/utils/enum.py +++ b/clinicadl/utils/enum.py @@ -1,6 +1,17 @@ from enum import Enum +class BaseEnum(Enum): + """Base Enum object that will print valid inputs if the value passed is not valid.""" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not a valid {cls.__name__}. Valid ones are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + class Task(str, Enum): """Tasks that can be performed in ClinicaDL.""" diff --git a/tests/unittests/nn/__init__.py b/tests/unittests/nn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/nn/blocks/__init__.py b/tests/unittests/nn/blocks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/nn/blocks/test_decoder.py b/tests/unittests/nn/blocks/test_decoder.py new file mode 100644 index 000000000..01bf7aef1 --- /dev/null +++ b/tests/unittests/nn/blocks/test_decoder.py @@ -0,0 +1,59 @@ +import pytest +import torch + +import clinicadl.nn.blocks.decoder as decoder + + +@pytest.fixture +def input_2d(): + return torch.randn(2, 1, 10, 10) + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 1, 10, 10, 10) + + +@pytest.fixture +def latent_vector(): + return torch.randn(2, 3) + + +@pytest.fixture(params=["latent_vector", "input_2d"]) +def to_decode(request): + return request.getfixturevalue(request.param) + + +def test_decoder2d(input_2d): + network = decoder.Decoder2D( + input_channels=input_2d.shape[1], output_channels=(input_2d.shape[1] + 3) + ) + output_2d = network(input_2d) + assert output_2d.shape[1] == input_2d.shape[1] + 3 + assert len(output_2d.shape) == 4 + + +def test_vae_decoder2d(to_decode): + latent_dim = 1 if len(to_decode.shape) == 2 else 2 + + network = decoder.VAE_Decoder2D( + input_shape=(1, 5, 5), + latent_size=to_decode.shape[1], + n_conv=1, + last_layer_channels=2, + latent_dim=latent_dim, + feature_size=4, + ) + output_2d = network(to_decode) + assert len(output_2d.shape) == 4 + assert output_2d.shape[0] == 2 + assert output_2d.shape[1] == 1 + + +def test_decoder3d(input_3d): + network = decoder.Decoder3D( + input_channels=input_3d.shape[1], output_channels=(input_3d.shape[1] + 3) + ) + output_3d = network(input_3d) + assert output_3d.shape[1] == input_3d.shape[1] + 3 + assert len(output_3d.shape) == 5 diff --git a/tests/unittests/nn/blocks/test_encoder.py b/tests/unittests/nn/blocks/test_encoder.py new file mode 100644 index 000000000..dcb676f96 --- /dev/null +++ b/tests/unittests/nn/blocks/test_encoder.py @@ -0,0 +1,46 @@ +import pytest +import torch + +import clinicadl.nn.blocks.encoder as encoder + + +@pytest.fixture +def input_2d(): + return torch.randn(2, 1, 10, 10) + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 1, 10, 10, 10) + + +def test_encoder2d(input_2d): + network = encoder.Encoder2D( + input_channels=input_2d.shape[1], output_channels=(input_2d.shape[1] + 3) + ) + output_2d = network(input_2d) + assert output_2d.shape[1] == input_2d.shape[1] + 3 + assert len(output_2d.shape) == 4 + + +@pytest.mark.parametrize("latent_dim", [1, 2]) +def test_vae_encoder2d(latent_dim, input_2d): + network = encoder.VAE_Encoder2D( + input_shape=(1, 10, 10), + n_conv=1, + first_layer_channels=4, + latent_dim=latent_dim, + feature_size=4, + ) + output = network(input_2d) + assert output.shape[0] == 2 + assert len(output.shape) == 2 if latent_dim == 1 else 4 + + +def test_encoder3d(input_3d): + network = encoder.Encoder3D( + input_channels=input_3d.shape[1], output_channels=(input_3d.shape[1] + 3) + ) + output_3d = network(input_3d) + assert output_3d.shape[1] == input_3d.shape[1] + 3 + assert len(output_3d.shape) == 5 diff --git a/tests/unittests/nn/blocks/test_residual.py b/tests/unittests/nn/blocks/test_residual.py new file mode 100644 index 000000000..302051ee3 --- /dev/null +++ b/tests/unittests/nn/blocks/test_residual.py @@ -0,0 +1,9 @@ +import torch + +from clinicadl.nn.blocks import ResBlock + + +def test_resblock(): + input_ = torch.randn((2, 4, 5, 5, 5)) + resblock = ResBlock(block_number=1, input_size=4) + assert resblock(input_).shape == torch.Size((2, 8, 5, 5, 5)) diff --git a/tests/unittests/nn/blocks/test_se.py b/tests/unittests/nn/blocks/test_se.py new file mode 100644 index 000000000..2444bcc3a --- /dev/null +++ b/tests/unittests/nn/blocks/test_se.py @@ -0,0 +1,33 @@ +import pytest +import torch + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 6, 10, 10, 10) + + +def test_SE_Block(input_3d): + from clinicadl.nn.blocks import SE_Block + + layer = SE_Block(num_channels=input_3d.shape[1], ratio_channel=4) + out = layer(input_3d) + assert out.shape == input_3d.shape + + +def test_ResBlock_SE(input_3d): + from clinicadl.nn.blocks import ResBlock_SE + + layer = ResBlock_SE( + num_channels=input_3d.shape[1], + block_number=1, + input_size=input_3d.shape[1], + ratio_channel=4, + ) + out = layer(input_3d) + assert out.shape[:2] == torch.Size( + ( + input_3d.shape[0], + 2**3, + ) + ) diff --git a/tests/unittests/nn/blocks/test_unet.py b/tests/unittests/nn/blocks/test_unet.py new file mode 100644 index 000000000..4e7170d77 --- /dev/null +++ b/tests/unittests/nn/blocks/test_unet.py @@ -0,0 +1,36 @@ +import pytest +import torch + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 4, 10, 10, 10) + + +@pytest.fixture +def skip_input(): + return torch.randn(2, 4, 10, 10, 10) + + +def test_UNetDown(input_3d): + from clinicadl.nn.blocks import UNetDown + + layer = UNetDown(in_size=input_3d.shape[1], out_size=8) + out = layer(input_3d) + assert out.shape[:2] == torch.Size((input_3d.shape[0], 8)) + + +def test_UNetUp(input_3d, skip_input): + from clinicadl.nn.blocks import UNetUp + + layer = UNetUp(in_size=input_3d.shape[1] * 2, out_size=2) + out = layer(input_3d, skip_input=skip_input) + assert out.shape[:2] == torch.Size((input_3d.shape[0], 2)) + + +def test_UNetFinalLayer(input_3d, skip_input): + from clinicadl.nn.blocks import UNetFinalLayer + + layer = UNetFinalLayer(in_size=input_3d.shape[1] * 2, out_size=2) + out = layer(input_3d, skip_input=skip_input) + assert out.shape[:2] == torch.Size((input_3d.shape[0], 2)) diff --git a/tests/unittests/nn/layers/__init__.py b/tests/unittests/nn/layers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/nn/layers/factory/__init__.py b/tests/unittests/nn/layers/factory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/nn/layers/factory/test_factories.py b/tests/unittests/nn/layers/factory/test_factories.py new file mode 100644 index 000000000..7036cc724 --- /dev/null +++ b/tests/unittests/nn/layers/factory/test_factories.py @@ -0,0 +1,27 @@ +import pytest +import torch.nn as nn + + +def test_get_conv_layer(): + from clinicadl.nn.layers.factory import get_conv_layer + + assert get_conv_layer(2) == nn.Conv2d + assert get_conv_layer(3) == nn.Conv3d + with pytest.raises(AssertionError): + get_conv_layer(1) + + +def test_get_norm_layer(): + from clinicadl.nn.layers.factory import get_norm_layer + + assert get_norm_layer("InstanceNorm", 2) == nn.InstanceNorm2d + assert get_norm_layer("BatchNorm", 3) == nn.BatchNorm3d + assert get_norm_layer("GroupNorm", 3) == nn.GroupNorm + + +def test_get_pool_layer(): + from clinicadl.nn.layers import PadMaxPool3d + from clinicadl.nn.layers.factory import get_pool_layer + + assert get_pool_layer("MaxPool", 2) == nn.MaxPool2d + assert get_pool_layer("PadMaxPool", 3) == PadMaxPool3d diff --git a/tests/unittests/nn/layers/test_layers.py b/tests/unittests/nn/layers/test_layers.py new file mode 100644 index 000000000..e07eb1cf6 --- /dev/null +++ b/tests/unittests/nn/layers/test_layers.py @@ -0,0 +1,101 @@ +import pytest +import torch + +import clinicadl.nn.layers as layers + + +@pytest.fixture +def input_2d(): + return torch.randn(2, 1, 5, 5) + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 1, 5, 5, 5) + + +def test_pool_layers(input_2d, input_3d): + output_3d = layers.PadMaxPool3d(kernel_size=2, stride=1)(input_3d) + output_2d = layers.PadMaxPool2d(kernel_size=2, stride=1)(input_2d) + + assert len(output_3d.shape) == 5 # TODO : test more precisely and test padding + assert output_3d.shape[0] == 2 + assert len(output_2d.shape) == 4 + assert output_2d.shape[0] == 2 + + +def test_unpool_layers(): # TODO : test padding + import torch.nn as nn + + pool = nn.MaxPool2d(2, stride=2, return_indices=True) + unpool = layers.CropMaxUnpool2d(2, stride=2) + input_ = torch.tensor( + [ + [ + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + ] + ] + ) + excpected_output = torch.tensor( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 0.0, 8.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 14.0, 0.0, 16.0], + ] + ] + ] + ) + output, indices = pool(input_) + assert (unpool(output, indices) == excpected_output).all() + + pool = nn.MaxPool3d(2, stride=1, return_indices=True) + unpool = layers.CropMaxUnpool3d(2, stride=1) + input_ = torch.tensor([[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]]) + excpected_output = torch.tensor( + [[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 8.0]]]] + ) + output, indices = pool(input_) + assert (unpool(output, indices) == excpected_output).all() + + +def test_unflatten_layers(): + flattened_2d = torch.randn(2, 1 * 5 * 4) + flattened_3d = torch.randn(2, 1 * 5 * 4 * 3) + + output_3d = layers.Unflatten3D(channel=1, height=5, width=4, depth=3)(flattened_3d) + output_2d = layers.Unflatten2D(channel=1, height=5, width=4)(flattened_2d) + + assert output_3d.shape == torch.Size((2, 1, 5, 4, 3)) + assert output_2d.shape == torch.Size((2, 1, 5, 4)) + + +def test_reshape_layers(input_2d): + reshape = layers.Reshape((2, 1, 25)) + assert reshape(input_2d).shape == torch.Size((2, 1, 25)) + + +def test_gradient_reversal(input_3d): + from copy import deepcopy + + import torch.nn as nn + + input_ = torch.randn(2, 5) + ref_ = torch.randn(2, 3) + layer = nn.Linear(5, 3) + reversed_layer = nn.Sequential(deepcopy(layer), layers.GradientReversal(alpha=2.0)) + criterion = torch.nn.MSELoss() + + criterion(layer(input_), ref_).backward() + criterion(reversed_layer(input_), ref_).backward() + assert all( + (p2.grad == -2.0 * p1.grad).all() + for p1, p2 in zip(layer.parameters(), reversed_layer.parameters()) + ) diff --git a/tests/unittests/nn/networks/__init__.py b/tests/unittests/nn/networks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/nn/networks/factory/__init__.py b/tests/unittests/nn/networks/factory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/nn/networks/factory/test_ae_factory.py b/tests/unittests/nn/networks/factory/test_ae_factory.py new file mode 100644 index 000000000..a4fe1a762 --- /dev/null +++ b/tests/unittests/nn/networks/factory/test_ae_factory.py @@ -0,0 +1,68 @@ +import pytest +import torch +import torch.nn as nn + +from clinicadl.nn.layers import ( + PadMaxPool2d, + PadMaxPool3d, +) + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 4, 10, 10, 10) + + +@pytest.fixture +def input_2d(): + return torch.randn(2, 4, 10, 10) + + +@pytest.fixture +def cnn3d(): + class CNN(nn.Module): + def __init__(self, input_size): + super().__init__() + self.convolutions = nn.Sequential( + nn.Conv3d(in_channels=input_size[0], out_channels=4, kernel_size=3), + nn.BatchNorm3d(num_features=4), + nn.LeakyReLU(), + PadMaxPool3d(kernel_size=2, stride=1, return_indices=False), + ) + self.fc = nn.Sequential( + nn.Flatten(), + nn.Linear(42, 2), + ) + + return CNN + + +@pytest.fixture +def cnn2d(): + class CNN(nn.Module): + def __init__(self, input_size): + super().__init__() + self.convolutions = nn.Sequential( + nn.Conv2d(in_channels=input_size[0], out_channels=4, kernel_size=3), + nn.BatchNorm2d(num_features=4), + nn.LeakyReLU(), + PadMaxPool2d(kernel_size=2, stride=1, return_indices=False), + ) + self.fc = nn.Sequential( + nn.Flatten(), + nn.Linear(42, 2), # should not raise an error + ) + + return CNN + + +@pytest.mark.parametrize("input, cnn", [("input_3d", "cnn3d"), ("input_2d", "cnn2d")]) +def test_autoencoder_from_cnn(input, cnn, request): + from clinicadl.nn.networks.ae import AE + from clinicadl.nn.networks.factory import autoencoder_from_cnn + + input_ = request.getfixturevalue(input) + cnn = request.getfixturevalue(cnn)(input_size=input_.shape[1:]) + encoder, decoder = autoencoder_from_cnn(cnn) + autoencoder = AE(encoder, decoder) + assert autoencoder(input_).shape == input_.shape diff --git a/tests/unittests/nn/networks/factory/test_resnet_factory.py b/tests/unittests/nn/networks/factory/test_resnet_factory.py new file mode 100644 index 000000000..1468d37ad --- /dev/null +++ b/tests/unittests/nn/networks/factory/test_resnet_factory.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn + + +def test_ResNetDesigner(): + from torchvision.models.resnet import BasicBlock + + from clinicadl.nn.networks.factory import ResNetDesigner + + input_ = torch.randn(2, 3, 100, 100) + + class Model(nn.Module): + def __init__(self): + super().__init__() + model = ResNetDesigner( + input_size=input_.shape[1:], + block=BasicBlock, + layers=[1, 2, 3, 4], + num_classes=2, + ) + self.convolutions = nn.Sequential( + model.conv1, + model.bn1, + model.relu, + model.maxpool, + model.layer1, + model.layer2, + model.layer3, + model.layer4, + model.avgpool, + ) + self.fc = nn.Sequential( + nn.Flatten(), + model.fc, + ) + + def forward(self, x): + return self.fc(self.convolutions(x)) + + model = Model() + + assert model(input_).shape == torch.Size([2, 2]) + + +def test_ResNetDesigner3D(): + from clinicadl.nn.networks.factory import ResNetDesigner3D + + input_ = torch.randn(2, 3, 100, 100, 100) + + class Model(nn.Module): + def __init__(self): + super().__init__() + model = ResNetDesigner3D( + input_size=input_.shape[1:], output_size=2, dropout=0.5 + ) + self.convolutions = nn.Sequential( + model.layer0, + model.layer1, + model.layer2, + model.layer3, + model.layer4, + ) + self.fc = model.fc + + def forward(self, x): + return self.fc(self.convolutions(x)) + + model = Model() + + assert model(input_).shape == torch.Size([2, 2]) diff --git a/tests/unittests/nn/networks/factory/test_secnn_factory.py b/tests/unittests/nn/networks/factory/test_secnn_factory.py new file mode 100644 index 000000000..96be92620 --- /dev/null +++ b/tests/unittests/nn/networks/factory/test_secnn_factory.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + + +def test_SECNNDesigner3D(): + from clinicadl.nn.networks.factory import SECNNDesigner3D + + input_ = torch.randn(2, 3, 100, 100, 100) + + class Model(nn.Module): + def __init__(self): + super().__init__() + model = SECNNDesigner3D( + input_size=input_.shape[1:], output_size=2, dropout=0.5 + ) + self.convolutions = nn.Sequential( + model.layer0, + model.layer1, + model.layer2, + model.layer3, + model.layer4, + ) + self.fc = model.fc + + def forward(self, x): + return self.fc(self.convolutions(x)) + + model = Model() + + assert model(input_).shape == torch.Size([2, 2]) diff --git a/tests/unittests/nn/networks/test_ae.py b/tests/unittests/nn/networks/test_ae.py new file mode 100644 index 000000000..9c6152d35 --- /dev/null +++ b/tests/unittests/nn/networks/test_ae.py @@ -0,0 +1,25 @@ +import pytest +import torch + +import clinicadl.nn.networks.ae as ae + + +@pytest.mark.parametrize("network", [net.value for net in ae.AE2d]) +def test_2d_ae(network): + input_ = torch.randn(2, 3, 100, 100) + network = getattr(ae, network)(input_size=input_.shape[1:], dropout=0.5) + output = network(input_) + assert output.shape == input_.shape + + +@pytest.mark.parametrize("network", [net.value for net in ae.AE3d]) +def test_3d_ae(network): + input_ = torch.randn(2, 1, 49, 49, 49) + if network == "CAE_half": + network = getattr(ae, network)( + input_size=input_.shape[1:], latent_space_size=10 + ) + else: + network = getattr(ae, network)(input_size=input_.shape[1:], dropout=0.5) + output = network(input_) + assert output.shape == input_.shape diff --git a/tests/unittests/nn/networks/test_cnn.py b/tests/unittests/nn/networks/test_cnn.py new file mode 100644 index 000000000..3f6a0cb87 --- /dev/null +++ b/tests/unittests/nn/networks/test_cnn.py @@ -0,0 +1,32 @@ +import pytest +import torch + +import clinicadl.nn.networks.cnn as cnn + + +@pytest.fixture +def input_2d(): + return torch.randn(2, 3, 100, 100) + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 1, 100, 100, 100) + + +@pytest.mark.parametrize("network", [net.value for net in cnn.CNN2d]) +def test_2d_cnn(network, input_2d): + network = getattr(cnn, network)( + input_size=input_2d.shape[1:], output_size=3, dropout=0.5 + ) + output_2d = network(input_2d) + assert output_2d.shape == (2, 3) + + +@pytest.mark.parametrize("network", [net.value for net in cnn.CNN3d]) +def test_3d_cnn(network, input_3d): + network = getattr(cnn, network)( + input_size=input_3d.shape[1:], output_size=1, dropout=0.5 + ) + output_2d = network(input_3d) + assert output_2d.shape == (2, 1) diff --git a/tests/unittests/nn/networks/test_ssda.py b/tests/unittests/nn/networks/test_ssda.py new file mode 100644 index 000000000..06da85ff2 --- /dev/null +++ b/tests/unittests/nn/networks/test_ssda.py @@ -0,0 +1,11 @@ +import torch + +from clinicadl.nn.networks.ssda import Conv5_FC3_SSDA + + +def test_UNet(): + input_ = torch.randn(2, 1, 64, 63, 62) + network = Conv5_FC3_SSDA(input_size=(1, 64, 63, 62), output_size=3) + output = network(input_) + for out in output: + assert out.shape == torch.Size((2, 3)) diff --git a/tests/unittests/nn/networks/test_unet.py b/tests/unittests/nn/networks/test_unet.py new file mode 100644 index 000000000..ba0408cdb --- /dev/null +++ b/tests/unittests/nn/networks/test_unet.py @@ -0,0 +1,9 @@ +import torch + +from clinicadl.nn.networks.unet import UNet + + +def test_UNet(): + input_ = torch.randn(2, 1, 64, 64, 64) # TODO : specify the size that works + network = UNet() + assert network(input_).shape == input_.shape diff --git a/tests/unittests/nn/networks/test_vae.py b/tests/unittests/nn/networks/test_vae.py new file mode 100644 index 000000000..308a2f185 --- /dev/null +++ b/tests/unittests/nn/networks/test_vae.py @@ -0,0 +1,87 @@ +import pytest +import torch + +import clinicadl.nn.networks.vae as vae + + +@pytest.fixture +def input_2d(): + return torch.randn(2, 3, 100, 100) + + +@pytest.fixture +def input_3d(): + return torch.randn(2, 1, 50, 50, 50) + + +@pytest.mark.parametrize( + "input_,network,latent_space_size", + [ + ( + torch.randn(2, 3, 100, 100), + vae.VanillaDenseVAE( + input_size=(3, 100, 100), latent_space_size=10, feature_size=100 + ), + 10, + ), + ( + torch.randn(2, 1, 80, 96, 80), + vae.VanillaDenseVAE3D( + size_reduction_factor=2, + latent_space_size=10, + feature_size=100, + ), + 10, + ), + # ( + # torch.randn(2, 1, 50, 50, 50), # TODO : only work with certain size + # vae.CVAE_3D( + # input_size=(3, 50, 50, 50), + # latent_space_size=10, + # ), + # 10, + # ), + ( + torch.randn(2, 1, 56, 64, 56), + vae.CVAE_3D_final_conv( + size_reduction_factor=3, + latent_space_size=10, + ), + 10, + ), + ( + torch.randn(2, 1, 32, 40, 32), + vae.CVAE_3D_half( + size_reduction_factor=5, + latent_space_size=10, + ), + 10, + ), + ], +) +def test_DenseVAEs(input_, network, latent_space_size): + output = network(input_) + + assert output[0].shape == torch.Size((input_.shape[0], latent_space_size)) + assert output[1].shape == torch.Size((input_.shape[0], latent_space_size)) + assert output[2].shape == input_.shape + + +@pytest.mark.parametrize( + "input_,network", + [ + ( + torch.randn(2, 3, 100, 100), + vae.VanillaSpatialVAE(input_size=(3, 100, 100)), + ), + # (torch.randn(2, 3, 100, 100, 100), vae.VanillaSpatialVAE3D(input_size=(3, 100, 100, 100))), # TODO : output doesn't have the same size + ], +) +def test_SpatialVAEs(input_, network): + output = network(input_) + + assert output[0].shape[:2] == torch.Size((input_.shape[0], 1)) + assert len(output[0].shape) == len(input_.shape) + assert output[1].shape[:2] == torch.Size((input_.shape[0], 1)) + assert len(output[0].shape) == len(input_.shape) + assert output[2].shape == input_.shape diff --git a/tests/unittests/nn/test_utils.py b/tests/unittests/nn/test_utils.py new file mode 100644 index 000000000..bcd379613 --- /dev/null +++ b/tests/unittests/nn/test_utils.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn + + +def test_compute_output_size(): + from clinicadl.nn.utils import compute_output_size + + input_2d = torch.randn(3, 2, 100, 100) + input_3d = torch.randn(3, 1, 100, 100, 100) + indices_2d = torch.randint(0, 100, size=(3, 2, 100, 100)) + + conv3d = nn.Conv3d( + in_channels=1, + out_channels=1, + kernel_size=7, + stride=2, + padding=(1, 2, 3), + dilation=3, + ) + max_pool3d = nn.MaxPool3d(kernel_size=(9, 8, 7), stride=1, padding=3, dilation=2) + conv_transpose2d = nn.ConvTranspose2d( + in_channels=2, + out_channels=1, + kernel_size=7, + stride=(4, 3), + padding=0, + dilation=(2, 1), + output_padding=1, + ) + max_unpool2d = nn.MaxUnpool2d(kernel_size=7, stride=(2, 1), padding=(1, 1)) + sequential = nn.Sequential( + conv3d, nn.Dropout(p=0.5), nn.BatchNorm3d(num_features=1), max_pool3d + ) + + assert compute_output_size(input_3d.shape[1:], conv3d) == tuple( + conv3d(input_3d).shape[1:] + ) + assert compute_output_size(input_3d.shape[1:], max_pool3d) == tuple( + max_pool3d(input_3d).shape[1:] + ) + assert compute_output_size(input_2d.shape[1:], conv_transpose2d) == tuple( + conv_transpose2d(input_2d).shape[1:] + ) + assert compute_output_size(input_2d.shape[1:], max_unpool2d) == tuple( + max_unpool2d(input_2d, indices_2d).shape[1:] + ) + assert compute_output_size(tuple(input_3d.shape[1:]), sequential) == tuple( + sequential(input_3d).shape[1:] + )