Basic neural networks #660

Merged: 36 commits, Oct 18, 2024

Commits (36)
1fb6d19  utils (thibaultdvx, Oct 1, 2024)
586cc9c  test utils (thibaultdvx, Oct 1, 2024)
0c29365  layers (thibaultdvx, Oct 1, 2024)
50fdf12  test layers (thibaultdvx, Oct 1, 2024)
b592dc4  mlp (thibaultdvx, Oct 1, 2024)
70b45f7  fully convolutional (thibaultdvx, Oct 1, 2024)
60b8304  cnn (thibaultdvx, Oct 1, 2024)
5865ec2  generator (thibaultdvx, Oct 1, 2024)
683570f  autoencoder (thibaultdvx, Oct 1, 2024)
348754f  vae (thibaultdvx, Oct 1, 2024)
402f942  put enum back (thibaultdvx, Oct 1, 2024)
0948792  change name fcn to conv (thibaultdvx, Oct 1, 2024)
24157c6  densenet (thibaultdvx, Oct 9, 2024)
7e2dbbf  resnet (thibaultdvx, Oct 9, 2024)
2d60d14  senet (thibaultdvx, Oct 9, 2024)
3a31b97  first try vit (thibaultdvx, Oct 10, 2024)
dd37be7  vit first try (thibaultdvx, Oct 10, 2024)
038a8e8  vit new version (thibaultdvx, Oct 11, 2024)
f314519  unet (thibaultdvx, Oct 11, 2024)
25a6026  attention unet (thibaultdvx, Oct 11, 2024)
8174da2  layers (thibaultdvx, Oct 14, 2024)
63cddac  small modifications on well-known networks (thibaultdvx, Oct 14, 2024)
ea22a45  reorganize (thibaultdvx, Oct 14, 2024)
745f580  uniformization (thibaultdvx, Oct 14, 2024)
c1b6fdf  add adaptive pooling (thibaultdvx, Oct 14, 2024)
26a0fa8  change layer names (thibaultdvx, Oct 14, 2024)
2b5f189  remove input_shape for convencoder and convdecoder (thibaultdvx, Oct 15, 2024)
6a13dcb  enable initial and final (un)pooling (thibaultdvx, Oct 15, 2024)
1cd7597  add transposed conv for unpooling (thibaultdvx, Oct 16, 2024)
5a06a3e  solve sequence issue (thibaultdvx, Oct 16, 2024)
0dc0600  add check for adn ordering (thibaultdvx, Oct 17, 2024)
361532f  config classes (thibaultdvx, Oct 17, 2024)
8c6dddb  modift factory function (thibaultdvx, Oct 17, 2024)
18409aa  test config (thibaultdvx, Oct 17, 2024)
475ca50  factory function (thibaultdvx, Oct 18, 2024)
aa36a2b  tuple typing issue (thibaultdvx, Oct 18, 2024)
4 changes: 2 additions & 2 deletions clinicadl/monai_networks/__init__.py
@@ -1,2 +1,2 @@
-from .config import ImplementedNetworks, NetworkConfig, create_network_config
-from .factory import get_network
+from .config import ImplementedNetworks, NetworkConfig
+from .factory import get_network, get_network_from_config
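For downstream code, the net effect of this re-export change is the following import surface (a minimal sketch; create_network_config is no longer re-exported at this level but stays available from the config subpackage, and the factory signatures themselves are defined in factory.py, outside this excerpt):

# New top-level imports after this PR
from clinicadl.monai_networks import (
    ImplementedNetworks,
    NetworkConfig,
    get_network,
    get_network_from_config,
)

# create_network_config moved down one level:
from clinicadl.monai_networks.config import create_network_config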
3 changes: 1 addition & 2 deletions clinicadl/monai_networks/config/__init__.py
@@ -1,3 +1,2 @@
-from .base import NetworkConfig
+from .base import ImplementedNetworks, NetworkConfig, NetworkType
 from .factory import create_network_config
-from .utils.enum import ImplementedNetworks
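Correspondingly, the enum and the new NetworkType marker now come from base.py rather than from a utils.enum module, so the config subpackage exposes them directly (sketch):

from clinicadl.monai_networks.config import (
    ImplementedNetworks,
    NetworkConfig,
    NetworkType,
    create_network_config,
)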
92 changes: 24 additions & 68 deletions clinicadl/monai_networks/config/autoencoder.py
@@ -1,89 +1,45 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Sequence, Union
 
-from pydantic import (
-    NonNegativeInt,
-    PositiveInt,
-    computed_field,
-    model_validator,
-)
+from pydantic import PositiveInt, computed_field
 
+from clinicadl.monai_networks.nn.layers.utils import (
+    ActivationParameters,
+    UnpoolingMode,
+)
 from clinicadl.utils.factories import DefaultFromLibrary
 
-from .base import VaryingDepthNetworkConfig
-from .utils.enum import ImplementedNetworks
-
-__all__ = ["AutoEncoderConfig", "VarAutoEncoderConfig"]
+from .base import ImplementedNetworks, NetworkConfig
+from .conv_encoder import ConvEncoderOptions
+from .mlp import MLPOptions
 
 
-class AutoEncoderConfig(VaryingDepthNetworkConfig):
-    """Config class for autoencoders."""
-
-    spatial_dims: PositiveInt
-    in_channels: PositiveInt
-    out_channels: PositiveInt
-
-    inter_channels: Union[
-        Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    inter_dilations: Union[
-        Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    num_inter_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
-    padding: Union[
-        Optional[Union[PositiveInt, Tuple[PositiveInt, ...]]], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
+class AutoEncoderConfig(NetworkConfig):
+    """Config class for AutoEncoder."""
+
+    in_shape: Sequence[PositiveInt]
+    latent_size: PositiveInt
+    conv_args: ConvEncoderOptions
+    mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES
+    out_channels: Union[
+        Optional[PositiveInt], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    output_act: Union[
+        Optional[ActivationParameters], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    unpooling_mode: Union[UnpoolingMode, DefaultFromLibrary] = DefaultFromLibrary.YES
 
     @computed_field
     @property
-    def network(self) -> ImplementedNetworks:
+    def name(self) -> ImplementedNetworks:
         """The name of the network."""
         return ImplementedNetworks.AE
 
-    @computed_field
-    @property
-    def dim(self) -> int:
-        """Dimension of the images."""
-        return self.spatial_dims
-
-    @model_validator(mode="after")
-    def model_validator(self):
-        """Checks coherence between parameters."""
-        if self.padding != DefaultFromLibrary.YES:
-            assert self._check_dimensions(
-                self.padding
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for padding. You passed {self.padding}."
-        if isinstance(self.inter_channels, tuple) and isinstance(
-            self.inter_dilations, tuple
-        ):
-            assert len(self.inter_channels) == len(
-                self.inter_dilations
-            ), "inter_channels and inter_dilations muust have the same size."
-        elif isinstance(self.inter_dilations, tuple) and not isinstance(
-            self.inter_channels, tuple
-        ):
-            raise ValueError(
-                "You passed inter_dilations but didn't pass inter_channels."
-            )
-        return self
-
 
-class VarAutoEncoderConfig(AutoEncoderConfig):
-    """Config class for variational autoencoders."""
-
-    in_shape: Tuple[PositiveInt, ...]
-    in_channels: Optional[int] = None
-    latent_size: PositiveInt
-    use_sigmoid: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+class VAEConfig(AutoEncoderConfig):
+    """Config class for Variational AutoEncoder."""
 
     @computed_field
     @property
-    def network(self) -> ImplementedNetworks:
+    def name(self) -> ImplementedNetworks:
         """The name of the network."""
         return ImplementedNetworks.VAE
-
-    @model_validator(mode="after")
-    def model_validator_bis(self):
-        """Checks coherence between parameters."""
-        assert (
-            len(self.in_shape[1:]) == self.spatial_dims
-        ), f"You passed {self.spatial_dims} for spatial_dims, but in_shape suggests {len(self.in_shape[1:])} spatial dimensions."
210 changes: 70 additions & 140 deletions clinicadl/monai_networks/config/base.py
@@ -1,168 +1,98 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, Optional, Tuple, Union
-
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    NonNegativeFloat,
-    NonNegativeInt,
-    PositiveInt,
-    computed_field,
-    field_validator,
-    model_validator,
-)
+from typing import Optional, Union
+
+from pydantic import BaseModel, ConfigDict, PositiveInt, computed_field
+
+from clinicadl.monai_networks.nn.layers.utils import ActivationParameters
 from clinicadl.utils.factories import DefaultFromLibrary
 
-from .utils.enum import (
-    ImplementedActFunctions,
-    ImplementedNetworks,
-    ImplementedNormLayers,
-)
+
+class ImplementedNetworks(str, Enum):
+    """Implemented neural networks in ClinicaDL."""
+
+    MLP = "MLP"
+    CONV_ENCODER = "ConvEncoder"
+    CONV_DECODER = "ConvDecoder"
+    CNN = "CNN"
+    GENERATOR = "Generator"
+    AE = "AutoEncoder"
+    VAE = "VAE"
+    DENSENET = "DenseNet"
+    DENSENET_121 = "DenseNet-121"
+    DENSENET_161 = "DenseNet-161"
+    DENSENET_169 = "DenseNet-169"
+    DENSENET_201 = "DenseNet-201"
+    RESNET = "ResNet"
+    RESNET_18 = "ResNet-18"
+    RESNET_34 = "ResNet-34"
+    RESNET_50 = "ResNet-50"
+    RESNET_101 = "ResNet-101"
+    RESNET_152 = "ResNet-152"
+    SE_RESNET = "SEResNet"
+    SE_RESNET_50 = "SEResNet-50"
+    SE_RESNET_101 = "SEResNet-101"
+    SE_RESNET_152 = "SEResNet-152"
+    UNET = "UNet"
+    ATT_UNET = "AttentionUNet"
+    VIT = "ViT"
+    VIT_B_16 = "ViT-B/16"
+    VIT_B_32 = "ViT-B/32"
+    VIT_L_16 = "ViT-L/16"
+    VIT_L_32 = "ViT-L/32"
+
+    @classmethod
+    def _missing_(cls, value):
+        raise ValueError(
+            f"{value} is not implemented. Implemented neural networks are: "
+            + ", ".join([repr(m.value) for m in cls])
+        )
+
+
+class NetworkType(str, Enum):
+    """
+    Useful to know where to look for the network.
+    See :py:func:`clinicadl.monai_networks.factory.get_network`
+    """
+
+    CUSTOM = "custom"  # our own networks
+    RESNET = "sota-ResNet"
+    DENSENET = "sota-DenseNet"
+    SE_RESNET = "sota-SEResNet"
+    VIT = "sota-ViT"
 
 
 class NetworkConfig(BaseModel, ABC):
     """Base config class to configure neural networks."""
 
-    kernel_size: Union[
-        PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    up_kernel_size: Union[
-        PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    num_res_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
-    act: Union[
-        ImplementedActFunctions,
-        Tuple[ImplementedActFunctions, Dict[str, Any]],
-        DefaultFromLibrary,
-    ] = DefaultFromLibrary.YES
-    norm: Union[
-        ImplementedNormLayers,
-        Tuple[ImplementedNormLayers, Dict[str, Any]],
-        DefaultFromLibrary,
-    ] = DefaultFromLibrary.YES
-    bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
-    adn_ordering: Union[Optional[str], DefaultFromLibrary] = DefaultFromLibrary.YES
     # pydantic config
     model_config = ConfigDict(
         validate_assignment=True,
         use_enum_values=True,
         validate_default=True,
         protected_namespaces=(),
     )
 
     @computed_field
     @property
     @abstractmethod
-    def network(self) -> ImplementedNetworks:
+    def name(self) -> ImplementedNetworks:
         """The name of the network."""
 
     @computed_field
     @property
-    @abstractmethod
-    def dim(self) -> int:
-        """Dimension of the images."""
-
-    @classmethod
-    def base_validator_dropout(cls, v):
-        """Checks that dropout is between 0 and 1."""
-        if isinstance(v, float):
-            assert (
-                0 <= v <= 1
-            ), f"dropout must be between 0 and 1 but it has been set to {v}."
-        return v
-
-    @field_validator("kernel_size", "up_kernel_size")
-    @classmethod
-    def base_is_odd(cls, value, field):
-        """Checks if a field is odd."""
-        if value != DefaultFromLibrary.YES:
-            if isinstance(value, int):
-                value_ = (value,)
-            else:
-                value_ = value
-            for v in value_:
-                assert v % 2 == 1, f"{field.field_name} must be odd."
-        return value
-
-    @field_validator("adn_ordering", mode="after")
-    @classmethod
-    def base_adn_validator(cls, v):
-        """Checks ADN sequence."""
-        if v != DefaultFromLibrary.YES:
-            for letter in v:
-                assert (
-                    letter in {"A", "D", "N"}
-                ), f"adn_ordering must be composed by 'A', 'D' or/and 'N'. You passed {letter}."
-            assert len(v) == len(
-                set(v)
-            ), "adn_ordering cannot contain duplicated letter."
-
-        return v
-
-    @classmethod
-    def base_at_least_2d(cls, v, ctx):
-        """Checks that a tuple has at least a length of two."""
-        if isinstance(v, tuple):
-            assert (
-                len(v) >= 2
-            ), f"{ctx.field_name} should have at least two dimensions (with the first one for the channel)."
-        return v
-
-    @model_validator(mode="after")
-    def base_model_validator(self):
-        """Checks coherence between parameters."""
-        if self.kernel_size != DefaultFromLibrary.YES:
-            assert self._check_dimensions(
-                self.kernel_size
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for kernel_size. You passed {self.kernel_size}."
-        if self.up_kernel_size != DefaultFromLibrary.YES:
-            assert self._check_dimensions(
-                self.up_kernel_size
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for up_kernel_size. You passed {self.up_kernel_size}."
-        return self
-
-    def _check_dimensions(
-        self,
-        value: Union[float, Tuple[float, ...]],
-    ) -> bool:
-        """Checks if a tuple has the right dimension."""
-        if isinstance(value, tuple):
-            return len(value) == self.dim
-        return True
+    def _type(self) -> NetworkType:
+        """
+        To know where to look for the network.
+        Default to 'custom'.
+        """
+        return NetworkType.CUSTOM
 
 
-class VaryingDepthNetworkConfig(NetworkConfig, ABC):
-    """
-    Base config class to configure neural networks.
-    More precisely, we refer to MONAI's networks with 'channels' and 'strides' parameters.
-    """
-
-    channels: Tuple[PositiveInt, ...]
-    strides: Tuple[Union[PositiveInt, Tuple[PositiveInt, ...]], ...]
-    dropout: Union[
-        Optional[NonNegativeFloat], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-
-    @field_validator("dropout")
-    @classmethod
-    def validator_dropout(cls, v):
-        """Checks that dropout is between 0 and 1."""
-        return cls.base_validator_dropout(v)
-
-    @model_validator(mode="after")
-    def channels_strides_validator(self):
-        """Checks coherence between parameters."""
-        n_layers = len(self.channels)
-        assert (
-            len(self.strides) == n_layers
-        ), f"There are {n_layers} layers but you passed {len(self.strides)} strides."
-        for s in self.strides:
-            assert self._check_dimensions(
-                s
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for strides. You passed {s}."
-
-        return self
+class PreTrainedConfig(NetworkConfig):
+    """Base config class for SOTA networks."""
+
+    num_outputs: Optional[PositiveInt]
+    output_act: Union[
+        Optional[ActivationParameters], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    pretrained: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
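Two behaviors of the relocated enums are worth illustrating (a minimal sketch based only on the code above; nothing here is part of the diff):

from clinicadl.monai_networks.config.base import ImplementedNetworks, NetworkType

# The _missing_ hook turns an unknown network name into an explicit error
# that lists every implemented network.
try:
    ImplementedNetworks("LeNet")
except ValueError as err:
    print(err)  # "LeNet is not implemented. Implemented neural networks are: ..."

# Every config reports NetworkType.CUSTOM through _type unless a SOTA config
# (the ResNet/DenseNet/SEResNet/ViT families) overrides it; get_network uses
# this marker to decide where to look for the architecture.
print(NetworkType.CUSTOM.value)  # "custom"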