diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index e67cb3376f..afb6664bd9 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -30,6 +30,7 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock +from .spade_norm import SPADE from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py new file mode 100644 index 0000000000..918a02de6e --- /dev/null +++ b/monai/networks/blocks/spade_norm.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import ADN, Convolution + + +class SPADE(nn.Module): + """ + SPADE normalisation block based on the 2019 paper by Park et al. (doi: https://doi.org/10.48550/arXiv.1903.07291) + + Args: + label_nc: number of semantic labels + norm_nc: number of output channels + kernel_size: kernel size + spatial_dims: number of spatial dimensions + hidden_channels: number of channels in the intermediate gamma and beta layers + norm: type of base normalisation used before applying the SPADE normalisation + norm_params: parameters for the base normalisation + """ + + def __init__( + self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple = "INSTANCE", + norm_params: dict | None = None, + ) -> None: + super().__init__() + + if norm_params is None: + norm_params = {} + if len(norm_params) != 0: + norm = (norm, norm_params) + self.param_free_norm = ADN( + act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc + ) + self.mlp_shared = Convolution( + spatial_dims=spatial_dims, + in_channels=label_nc, + out_channels=hidden_channels, + kernel_size=kernel_size, + norm=None, + padding=kernel_size // 2, + act="LEAKYRELU", + ) + self.mlp_gamma = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding=kernel_size // 2, + act=None, + ) + self.mlp_beta = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + padding=kernel_size // 2, + act=None, + ) + + def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor + segmap: input segmentation map (bxcx[spatial-dimensions]) where c is the number of semantic channels. + The map will be interpolated to the dimension of x internally. + """ + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. 
produce scaling and bias conditioned on semantic map
+        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
+        actv = self.mlp_shared(segmap)
+        gamma = self.mlp_gamma(actv)
+        beta = self.mlp_beta(actv)
+        out = normalized * (1 + gamma) + beta
+        return out
diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py
index d61ed57f7f..bd3e3af3af 100644
--- a/monai/networks/layers/__init__.py
+++ b/monai/networks/layers/__init__.py
@@ -37,4 +37,5 @@
 )
 from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
 from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
+from .vector_quantizer import EMAQuantizer, VectorQuantizer
 from .weight_init import _no_grad_trunc_normal_, trunc_normal_
diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py
new file mode 100644
index 0000000000..c84aada32a
--- /dev/null
+++ b/monai/networks/layers/vector_quantizer.py
@@ -0,0 +1,233 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Sequence, Tuple
+
+import torch
+from torch import nn
+
+__all__ = ["VectorQuantizer", "EMAQuantizer"]
+
+
+class EMAQuantizer(nn.Module):
+    """
+    Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural
+    Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation
+    that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit
+    58d9a2746493717a7c9252938da7efa6006f3739.
+
+    This module is not compatible with TorchScript when used inside a DistributedDataParallel module, due to the lack
+    of TorchScript support for the torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353
+    (as of 22/10/2022). If you want to TorchScript your model, please set `ddp_sync` to False.
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        num_embeddings: number of atomic elements in the codebook.
+        embedding_dim: number of channels of the input and atomic elements.
+        commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.
+        decay: EMA decay. Defaults to 0.99.
+        epsilon: epsilon value. Defaults to 1e-5.
+        embedding_init: initialization method for the codebook. Defaults to "normal".
+        ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
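+
+    Example (a minimal sketch with toy sizes, assuming `torch` is imported as above):
+        quantizer = EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8)
+        x = torch.randn(2, 8, 4, 4)  # (B, C, H, W) with C == embedding_dim
+        quantized, loss, indices = quantizer(x)  # indices have shape (2, 4, 4)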
+ """ + + def __init__( + self, + spatial_dims: int, + num_embeddings: int, + embedding_dim: int, + commitment_cost: float = 0.25, + decay: float = 0.99, + epsilon: float = 1e-5, + embedding_init: str = "normal", + ddp_sync: bool = True, + ): + super().__init__() + self.spatial_dims: int = spatial_dims + self.embedding_dim: int = embedding_dim + self.num_embeddings: int = num_embeddings + + assert self.spatial_dims in [2, 3], ValueError( + f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." + ) + + self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) + if embedding_init == "normal": + # Initialization is passed since the default one is normal inside the nn.Embedding + pass + elif embedding_init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") + self.embedding.weight.requires_grad = False + + self.commitment_cost: float = commitment_cost + + self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) + self.register_buffer("ema_w", self.embedding.weight.data.clone()) + # declare types for mypy + self.ema_cluster_size: torch.Tensor + self.ema_w: torch.Tensor + self.decay: float = decay + self.epsilon: float = epsilon + + self.ddp_sync: bool = ddp_sync + + # Precalculating required permutation shapes + self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1] + self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( + range(1, self.spatial_dims + 1) + ) + + def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. + + Args: + inputs: Encoding space tensors + + Returns: + torch.Tensor: Flatten version of the input of shape [B*D*H*W, C]. + torch.Tensor: One-hot representation of the quantization indices of shape [B*D*H*W, self.num_embeddings]. + torch.Tensor: Quantization indices of shape [B,D,H,W,1] + + """ + with torch.autocast(device_type="cuda", enabled=False): + encoding_indices_view = list(inputs.shape) + del encoding_indices_view[1] + + inputs = inputs.float() + + # Converting to channel last format + flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) + + # Calculate Euclidean distances + distances = ( + (flat_input**2).sum(dim=1, keepdim=True) + + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) + - 2 * torch.mm(flat_input, self.embedding.weight.t()) + ) + + # Mapping distances to indexes + encoding_indices = torch.max(-distances, dim=1)[1] + encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() + + # Quantize and reshape + encoding_indices = encoding_indices.view(encoding_indices_view) + + return flat_input, encodings, encoding_indices + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + """ + Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space + [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the + decoder. + + Args: + embedding_indices: Tensor in channel last format which holds indices referencing atomic + elements from self.embedding + + Returns: + torch.Tensor: Quantize space representation of encoding_indices in channel first format. 
+ """ + with torch.autocast(device_type="cuda", enabled=False): + embedding: torch.Tensor = ( + self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() + ) + return embedding + + def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: + """ + TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the + example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused + + Args: + encodings_sum: The summation of one hot representation of what encoding was used for each + position. + dw: The multiplication of the one hot representation of what encoding was used for each + position with the flattened input. + + Returns: + None + """ + if self.ddp_sync and torch.distributed.is_initialized(): + torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) + else: + pass + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat_input, encodings, encoding_indices = self.quantize(inputs) + quantized = self.embed(encoding_indices) + + # Use EMA to update the embedding vectors + if self.training: + with torch.no_grad(): + encodings_sum = encodings.sum(0) + dw = torch.mm(encodings.t(), flat_input) + + if self.ddp_sync: + self.distributed_synchronization(encodings_sum, dw) + + self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) + + # Laplace smoothing of the cluster size + n = self.ema_cluster_size.sum() + weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n + self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) + self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) + + # Encoding Loss + loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + + return quantized, loss, encoding_indices + + +class VectorQuantizer(torch.nn.Module): + """ + Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of + the quantization in their own class. + + Args: + quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index + based quantized representation. 
+ """ + + def __init__(self, quantizer: EMAQuantizer): + super().__init__() + + self.quantizer: EMAQuantizer = quantizer + + self.perplexity: torch.Tensor = torch.rand(1) + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + quantized, loss, encoding_indices = self.quantizer(inputs) + # Perplexity calculations + avg_probs = ( + torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) + .float() + .div(encoding_indices.numel()) + ) + + self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return loss, quantized + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.quantizer.embed(embedding_indices=embedding_indices) + + def quantize(self, encodings: torch.Tensor) -> torch.Tensor: + output = self.quantizer(encodings) + encoding_indices: torch.Tensor = output[2] + return encoding_indices diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9247aaee85..a7ce16ad64 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -14,9 +14,11 @@ from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder +from .autoencoderkl import AutoencoderKL from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .controlnet import ControlNet from .daf3d import DAF3D from .densenet import ( DenseNet, @@ -34,6 +36,7 @@ densenet201, densenet264, ) +from .diffusion_model_unet import DiffusionModelUNet from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( @@ -52,6 +55,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet @@ -102,9 +106,12 @@ seresnext50, seresnext101, ) +from .spade_autoencoderkl import SPADEAutoencoderKL +from .spade_diffusion_model_unet import SPADEDiffusionModelUNet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex +from .transformer import DecoderOnlyTransformer from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder @@ -112,3 +119,4 @@ from .vitautoenc import ViTAutoEnc from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet +from .vqvae import VQVAE diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py new file mode 100644 index 0000000000..acdff6c96a --- /dev/null +++ b/monai/networks/nets/autoencoderkl.py @@ -0,0 +1,816 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib.util +import math +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.utils import ensure_tuple_rep + +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["AutoencoderKL"] + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based upsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + super().__init__() + if use_convtranspose: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True, + ) + else: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + return self.conv(x) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + x = self.conv(x) + return x + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. 
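+
+    Note:
+        The input is padded by one element at the end of each spatial dimension before the stride-2
+        convolution, so that even-sized spatial dimensions are exactly halved.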
+ """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class _ResBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + """ + + def __init__( + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = F.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Attention block. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + 
residual + + +class Encoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + num_channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True) + ) + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = 
nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.num_channels = num_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + reversed_block_out_channels = list(reversed(num_channels)) + + blocks = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + 
use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + _Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) + ) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class AutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + num_channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_checkpointing: if True, use activation checkpointing to save memory. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpointing: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
+            )
+
+        self.encoder = Encoder(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            num_channels=num_channels,
+            out_channels=latent_channels,
+            num_res_blocks=num_res_blocks,
+            norm_num_groups=norm_num_groups,
+            norm_eps=norm_eps,
+            attention_levels=attention_levels,
+            with_nonlocal_attn=with_encoder_nonlocal_attn,
+            use_flash_attention=use_flash_attention,
+        )
+        self.decoder = Decoder(
+            spatial_dims=spatial_dims,
+            num_channels=num_channels,
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            norm_num_groups=norm_num_groups,
+            norm_eps=norm_eps,
+            attention_levels=attention_levels,
+            with_nonlocal_attn=with_decoder_nonlocal_attn,
+            use_flash_attention=use_flash_attention,
+            use_convtranspose=use_convtranspose,
+        )
+        self.quant_conv_mu = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=latent_channels,
+            out_channels=latent_channels,
+            strides=1,
+            kernel_size=1,
+            padding=0,
+            conv_only=True,
+        )
+        self.quant_conv_log_sigma = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=latent_channels,
+            out_channels=latent_channels,
+            strides=1,
+            kernel_size=1,
+            padding=0,
+            conv_only=True,
+        )
+        self.post_quant_conv = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=latent_channels,
+            out_channels=latent_channels,
+            strides=1,
+            kernel_size=1,
+            padding=0,
+            conv_only=True,
+        )
+        self.latent_channels = latent_channels
+        self.use_checkpointing = use_checkpointing
+
+    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.
+
+        Args:
+            x: BxCx[SPATIAL DIMS] tensor
+
+        """
+        if self.use_checkpointing:
+            h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
+        else:
+            h = self.encoder(x)
+
+        z_mu = self.quant_conv_mu(h)
+        z_log_var = self.quant_conv_log_sigma(h)
+        z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
+        z_sigma = torch.exp(z_log_var / 2)
+
+        return z_mu, z_sigma
+
+    def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
+        """
+        From the mean and sigma representations resulting from encoding an image through the latent space,
+        obtains a sample by drawing Gaussian noise, multiplying it by the standard deviation (sigma) and
+        adding the mean.
+
+        Args:
+            z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
+            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] standard deviation vector obtained by the encoder when you
+                encode an image
+
+        Returns:
+            sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
+        """
+        eps = torch.randn_like(z_sigma)
+        z_vae = z_mu + eps * z_sigma
+        return z_vae
+
+    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encodes and decodes an input image.
+
+        Args:
+            x: BxCx[SPATIAL DIMENSIONS] tensor.
+
+        Returns:
+            reconstructed image, of the same shape as input
+        """
+        z_mu, _ = self.encode(x)
+        reconstruction = self.decode(z_mu)
+        return reconstruction
+
+    def decode(self, z: torch.Tensor) -> torch.Tensor:
+        """
+        Based on a latent space sample, forwards it through the Decoder.
+ + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + if self.use_checkpointing: + dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) + else: + dec = self.decoder(z) + return dec + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + image = self.decode(z) + return image diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py new file mode 100644 index 0000000000..ace44cba0a --- /dev/null +++ b/monai/networks/nets/controlnet.py @@ -0,0 +1,416 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding +from monai.utils import ensure_tuple_rep + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Network to encode the conditioning into a latent space. 
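+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels of the conditioning image.
+        out_channels: number of output channels (matches the first level of the ControlNet).
+        num_channels: sequence of channels for the intermediate convolutional blocks.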
+ """ + + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, num_channels: Sequence[int] = (16, 32, 96, 256) + ): + super().__init__() + + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(num_channels) - 1): + channel_in = num_channels[i] + channel_out = num_channels[i + 1] + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_in, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_out, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.conv_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNet(nn.Module): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
+ ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + spatial_dims=spatial_dims, + in_channels=conditioning_embedding_in_channels, + num_channels=conditioning_embedding_num_channels, + out_channels=num_channels[0], + ) + + # down + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block.conv) + self.controlnet_down_blocks.append(controlnet_block) + + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + for _ in range(num_res_blocks[i]): + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + # + if not is_final_block: + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = num_channels[-1] + + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + controlnet_block = 
Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=output_channel,
+            out_channels=output_channel,
+            strides=1,
+            kernel_size=1,
+            padding=0,
+            conv_only=True,
+        )
+        controlnet_block = zero_module(controlnet_block)
+        self.controlnet_mid_block = controlnet_block
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        context: torch.Tensor | None = None,
+        class_labels: torch.Tensor | None = None,
+    ) -> tuple[tuple[torch.Tensor], torch.Tensor]:
+        """
+        Args:
+            x: input tensor (N, C, SpatialDims).
+            timesteps: timestep tensor (N,).
+            controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims).
+            conditioning_scale: scaling factor applied to the ControlNet residual outputs.
+            context: context tensor (N, 1, ContextDim).
+            class_labels: class label tensor (N,).
+        """
+        # 1. time
+        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+        # `get_timestep_embedding` has no weights and always returns f32 tensors, but the time
+        # embedding MLP might actually be running in fp16, so we need to cast here.
+        # There might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=x.dtype)
+        emb = self.time_embed(t_emb)
+
+        # 2. class
+        if self.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels)
+            class_emb = class_emb.to(dtype=x.dtype)
+            emb = emb + class_emb
+
+        # 3. initial convolution
+        h = self.conv_in(x)
+
+        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+
+        h += controlnet_cond
+
+        # 4. down
+        if context is not None and self.with_conditioning is False:
+            raise ValueError("model should have with_conditioning = True if context is provided")
+        down_block_res_samples: list[torch.Tensor] = [h]
+        for downsample_block in self.down_blocks:
+            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+            for residual in res_samples:
+                down_block_res_samples.append(residual)
+
+        # 5. mid
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+        # 6. control net blocks
+        controlnet_down_block_res_samples = ()
+
+        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+            down_block_res_sample = controlnet_block(down_block_res_sample)
+            controlnet_down_block_res_samples += (down_block_res_sample,)
+
+        down_block_res_samples = controlnet_down_block_res_samples
+
+        mid_block_res_sample = self.controlnet_mid_block(h)
+
+        # 7. scaling
+        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+        mid_block_res_sample *= conditioning_scale
+
+        return down_block_res_samples, mid_block_res_sample
diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py
new file mode 100644
index 0000000000..4996005f21
--- /dev/null
+++ b/monai/networks/nets/diffusion_model_unet.py
@@ -0,0 +1,2144 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import importlib.util +import math +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep + +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False + + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["DiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class _CrossAttention(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A cross attention layer. + + Args: + query_dim: number of channels in the query. + cross_attention_dim: number of channels in the context. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each head. + dropout: dropout probability to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
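+
+    Note:
+        The internal attention width is `num_attention_heads * num_head_channels`; the output is always
+        projected back to `query_dim` channels, regardless of the context dimension.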
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + num_attention_heads: int = 8, + num_head_channels: int = 64, + dropout: float = 0.0, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + inner_dim = num_head_channels * num_attention_heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads + + self.upcast_attention = upcast_attention + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + query = self.to_q(x) + context = context if context is not None else x + key = self.to_k(context) + value = self.to_v(context) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + return self.to_out(x) + + +class _BasicTransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A basic Transformer block. 
+ + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = _CrossAttention( + query_dim=num_channels, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention + self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) + self.attn2 = _CrossAttention( + query_dim=num_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention if context is None + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class _SpatialTransformer(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + _BasicTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + if self.spatial_dims == 3: + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() + + x = self.proj_out(x) + return x + residual + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to + compute attention. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings following the implementation in Ho et al. 
"Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError("num_channels and out_channels must be equal when use_conv=False") + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + return self.op(x) + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Upsampling layer with an optional convolution. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each + dimension. 
+ """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + self.conv = None + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError("Input channels should be equal to num_channels") + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class _ResnetBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + + return self.skip_connection(x) + h + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + 
resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + 
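# Editor's sketch (hypothetical values): get_down_block above dispatches on the
# with_attn / with_cross_attn flags; with_attn=True yields an AttnDownBlock:
block = get_down_block(
    spatial_dims=2, in_channels=32, out_channels=64, temb_channels=128,
    num_res_blocks=2, norm_num_groups=32, norm_eps=1e-6, add_downsample=True,
    resblock_updown=False, with_attn=True, with_cross_attn=False,
    num_head_channels=8, transformer_num_layers=1, cross_attention_dim=None,
)
print(type(block).__name__)  # AttnDownBlock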
use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class DiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." 
+ ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + 
with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: class labels tensor (N,). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4.
down
+        if context is not None and self.with_conditioning is False:
+            raise ValueError("model should have with_conditioning = True if context is provided")
+        down_block_res_samples: list[torch.Tensor] = [h]
+        for downsample_block in self.down_blocks:
+            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+            for residual in res_samples:
+                down_block_res_samples.append(residual)
+
+        # Additional residual connections for ControlNets
+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples = ()
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples += (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
+        # 5. mid
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+        # Additional residual connections for ControlNets
+        if mid_block_additional_residual is not None:
+            h = h + mid_block_additional_residual
+
+        # 6. up
+        for upsample_block in self.up_blocks:
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
+
+        # 7. output block
+        h = self.out(h)
+
+        return h
+
+
+class DiffusionModelEncoder(nn.Module):
+    """
+    Classification network based on the encoder of the diffusion model, followed by fully connected layers.
+    This network is based on Wolleb et al. "Diffusion Models for Medical Anomaly Detection"
+    (https://arxiv.org/abs/2203.04306).
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        out_channels: number of output channels.
+        num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
+        num_channels: tuple of block output channels.
+        attention_levels: list of levels to add attention.
+        norm_num_groups: number of groups for the normalization.
+        norm_eps: epsilon for the normalization.
+        resblock_updown: if True use residual blocks for downsampling.
+        num_head_channels: number of channels in each attention head.
+        with_conditioning: if True add spatial transformers to perform conditioning.
+        transformer_num_layers: number of layers of Transformer blocks to use.
+        cross_attention_dim: number of context dimensions to use.
+        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
+        upcast_attention: if True, upcast attention operations to full precision.
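+
+    Example:
+        A minimal usage sketch (editor's illustration, not part of the original API docs). The input
+        size below is an assumption chosen so that, with the default configuration, the flattened
+        encoder features match the hard-coded 4096-dimensional linear layer::
+
+            import torch
+
+            net = DiffusionModelEncoder(spatial_dims=2, in_channels=1, out_channels=2)
+            x = torch.rand(1, 1, 128, 128)  # 128x128 -> downsampled to 8x8 with 64 channels = 4096 features
+            timesteps = torch.randint(0, 1000, (1,)).long()
+            logits = net(x, timesteps)  # shape (1, 2)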
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+        num_channels: Sequence[int] = (32, 64, 64, 64),
+        attention_levels: Sequence[bool] = (False, False, True, True),
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        resblock_updown: bool = False,
+        num_head_channels: int | Sequence[int] = 8,
+        with_conditioning: bool = False,
+        transformer_num_layers: int = 1,
+        cross_attention_dim: int | None = None,
+        num_class_embeds: int | None = None,
+        upcast_attention: bool = False,
+    ) -> None:
+        super().__init__()
+        if with_conditioning is True and cross_attention_dim is None:
+            raise ValueError(
+                "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "when using with_conditioning."
+            )
+        if cross_attention_dim is not None and with_conditioning is False:
+            raise ValueError(
+                "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
+            )
+
+        # All numbers of channels should be multiples of num_groups
+        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+            raise ValueError("DiffusionModelEncoder expects all num_channels to be multiples of norm_num_groups")
+        if len(num_channels) != len(attention_levels):
+            raise ValueError("DiffusionModelEncoder expects num_channels to have the same length as attention_levels")
+
+        if isinstance(num_head_channels, int):
+            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+        if len(num_head_channels) != len(attention_levels):
+            raise ValueError(
+                "num_head_channels should have the same length as attention_levels. For the levels without attention,"
+                " i.e. `attention_levels[i]=False`, num_head_channels[i] will be ignored."
+            )
+
+        self.in_channels = in_channels
+        self.block_out_channels = num_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_levels = attention_levels
+        self.num_head_channels = num_head_channels
+        self.with_conditioning = with_conditioning
+
+        # input
+        self.conv_in = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=num_channels[0],
+            strides=1,
+            kernel_size=3,
+            padding=1,
+            conv_only=True,
+        )
+
+        # time
+        time_embed_dim = num_channels[0] * 4
+        self.time_embed = nn.Sequential(
+            nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+        )
+
+        # class embedding
+        self.num_class_embeds = num_class_embeds
+        if num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+        # down
+        self.down_blocks = nn.ModuleList([])
+        output_channel = num_channels[0]
+        for i in range(len(num_channels)):
+            input_channel = output_channel
+            output_channel = num_channels[i]
+            # unlike the UNet above, the encoder downsamples at every level, so this flag is never True
+            is_final_block = i == len(num_channels)
+
+            down_block = get_down_block(
+                spatial_dims=spatial_dims,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                num_res_blocks=num_res_blocks[i],
+                norm_num_groups=norm_num_groups,
+                norm_eps=norm_eps,
+                add_downsample=not is_final_block,
+                resblock_updown=resblock_updown,
+                with_attn=(attention_levels[i] and not with_conditioning),
+                with_cross_attn=(attention_levels[i] and with_conditioning),
+                num_head_channels=num_head_channels[i],
+                transformer_num_layers=transformer_num_layers,
+                cross_attention_dim=cross_attention_dim,
+                upcast_attention=upcast_attention,
+            )
+
+            self.down_blocks.append(down_block)
+
+        # NOTE: the first linear layer assumes the flattened encoder output has 4096 features,
+        # e.g. the default configuration applied to 2D inputs of size 128x128; other configurations
+        # or input sizes require adapting this value.
+        self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        context: torch.Tensor | None = None,
+        class_labels: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: input tensor (N, C, SpatialDims).
+            timesteps: timestep tensor (N,).
+            context: context tensor (N, 1, ContextDim).
+            class_labels: class label tensor (N,).
+        """
+        # 1. time
+        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=x.dtype)
+        emb = self.time_embed(t_emb)
+
+        # 2.
class
+        if self.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels)
+            class_emb = class_emb.to(dtype=x.dtype)
+            emb = emb + class_emb
+
+        # 3. initial convolution
+        h = self.conv_in(x)
+
+        # 4. down
+        if context is not None and self.with_conditioning is False:
+            raise ValueError("model should have with_conditioning = True if context is provided")
+        for downsample_block in self.down_blocks:
+            h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
+
+        h = h.reshape(h.shape[0], -1)
+        output = self.out(h)
+
+        return output
diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py
new file mode 100644
index 0000000000..7fd9aaf491
--- /dev/null
+++ b/monai/networks/nets/patchgan_discriminator.py
@@ -0,0 +1,255 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import Convolution
+from monai.networks.layers import Act
+
+
+class MultiScalePatchDiscriminator(nn.Sequential):
+    """
+    Multi-scale PatchGAN discriminator based on Pix2PixHD:
+    "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs",
+    Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz and Bryan Catanzaro
+    (NVIDIA Corporation and UC Berkeley), CVPR 2018.
+    Multi-scale discriminator made up of several PatchGAN discriminators, each processing the image
+    down to a different spatial scale.
+
+    Args:
+        num_d: number of discriminators
+        num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in each
+            of the discriminators. In each layer, the number of channels is doubled and the spatial size is
+            divided by 2.
+        spatial_dims: number of spatial dimensions (1D, 2D etc.)
+        num_channels: number of filters in the first convolutional layer (the value is doubled from then on)
+        in_channels: number of input channels
+        out_channels: number of output channels in each discriminator
+        kernel_size: kernel size of the convolution layers
+        activation: activation layer type
+        norm: normalisation type
+        bias: whether the convolution layers use a bias term
+        dropout: proportion of dropout applied, defaults to 0.
+        minimum_size_im: minimum spatial size of the input image. Introduced to make sure the requested
+            architecture does not downsample the input image below a spatial size of 1.
+        last_conv_kernel_size: kernel size of the last convolutional layer.
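+
+    Example:
+        A minimal usage sketch (editor's illustration; the sizes and hyper-parameters below are
+        arbitrary assumptions, not recommended defaults)::
+
+            import torch
+
+            net = MultiScalePatchDiscriminator(
+                num_d=2, num_layers_d=3, spatial_dims=2, num_channels=8, in_channels=1, minimum_size_im=256
+            )
+            outputs, features = net(torch.rand(1, 1, 256, 256))
+            # one patch-wise prediction map and one list of intermediate features per discriminator
+            assert len(outputs) == 2 and len(features) == 2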
+    """
+
+    def __init__(
+        self,
+        num_d: int,
+        num_layers_d: int,
+        spatial_dims: int,
+        num_channels: int,
+        in_channels: int,
+        out_channels: int = 1,
+        kernel_size: int = 4,
+        activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+        norm: str | tuple = "BATCH",
+        bias: bool = False,
+        dropout: float | tuple = 0.0,
+        minimum_size_im: int = 256,
+        last_conv_kernel_size: int = 1,
+    ) -> None:
+        super().__init__()
+        self.num_d = num_d
+        self.num_layers_d = num_layers_d
+        self.num_channels = num_channels
+        self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)
+        for i_ in range(self.num_d):
+            num_layers_d_i = self.num_layers_d * (i_ + 1)
+            output_size = float(minimum_size_im) / (2**num_layers_d_i)
+            if output_size < 1:
+                raise AssertionError(
+                    "Your image size is too small to take in up to %d discriminators with num_layers = %d. "
+                    "Please reduce num_layers, reduce num_d or enter bigger images." % (i_ + 1, num_layers_d_i)
+                )
+            subnet_d = PatchDiscriminator(
+                spatial_dims=spatial_dims,
+                num_channels=self.num_channels,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                num_layers_d=num_layers_d_i,
+                kernel_size=kernel_size,
+                activation=activation,
+                norm=norm,
+                bias=bias,
+                padding=self.padding,
+                dropout=dropout,
+                last_conv_kernel_size=last_conv_kernel_size,
+            )
+
+            self.add_module("discriminator_%d" % i_, subnet_d)
+
+    def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
+        """
+        Args:
+            i: input tensor
+
+        Returns:
+            list of outputs and another list of lists with the intermediate features
+            of each discriminator.
+        """
+
+        out: list[torch.Tensor] = []
+        intermediate_features: list[list[torch.Tensor]] = []
+        for disc in self.children():
+            out_d: list[torch.Tensor] = disc(i)
+            out.append(out_d[-1])
+            intermediate_features.append(out_d[:-1])
+
+        return out, intermediate_features
+
+
+class PatchDiscriminator(nn.Sequential):
+    """
+    PatchGAN discriminator based on Pix2PixHD:
+    "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs",
+    Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz and Bryan Catanzaro
+    (NVIDIA Corporation and UC Berkeley), CVPR 2018.
+
+    Args:
+        spatial_dims: number of spatial dimensions (1D, 2D etc.)
+        num_channels: number of filters in the first convolutional layer (the value is doubled from then on)
+        in_channels: number of input channels
+        out_channels: number of output channels in each discriminator
+        num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in each
+            of the discriminators. In each layer, the number of channels is doubled and the spatial size is
+            divided by 2.
+        kernel_size: kernel size of the convolution layers
+        activation: activation layer type
+        norm: normalisation type
+        bias: whether the convolution layers use a bias term
+        padding: padding to be applied to the convolutional layers
+        dropout: proportion of dropout applied, defaults to 0.
+        last_conv_kernel_size: kernel size of the last convolutional layer.
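+
+    Example:
+        A minimal usage sketch (editor's illustration with arbitrary, assumed sizes)::
+
+            import torch
+
+            net = PatchDiscriminator(spatial_dims=2, num_channels=8, in_channels=1, num_layers_d=3)
+            features = net(torch.rand(1, 1, 64, 64))
+            prediction = features[-1]  # the last element is the patch-wise output map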
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        num_channels: int,
+        in_channels: int,
+        out_channels: int = 1,
+        num_layers_d: int = 3,
+        kernel_size: int = 4,
+        activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+        norm: str | tuple = "BATCH",
+        bias: bool = False,
+        padding: int | Sequence[int] = 1,
+        dropout: float | tuple = 0.0,
+        last_conv_kernel_size: int | None = None,
+    ) -> None:
+        super().__init__()
+        self.num_layers_d = num_layers_d
+        self.num_channels = num_channels
+        if last_conv_kernel_size is None:
+            last_conv_kernel_size = kernel_size
+
+        self.add_module(
+            "initial_conv",
+            Convolution(
+                spatial_dims=spatial_dims,
+                kernel_size=kernel_size,
+                in_channels=in_channels,
+                out_channels=num_channels,
+                act=activation,
+                bias=True,
+                norm=None,
+                dropout=dropout,
+                padding=padding,
+                strides=2,
+            ),
+        )
+
+        input_channels = num_channels
+        output_channels = num_channels * 2
+
+        # Intermediate layers (the initial convolution is added above; the last of these uses stride 1)
+        for l_ in range(self.num_layers_d):
+            if l_ == self.num_layers_d - 1:
+                stride = 1
+            else:
+                stride = 2
+            layer = Convolution(
+                spatial_dims=spatial_dims,
+                kernel_size=kernel_size,
+                in_channels=input_channels,
+                out_channels=output_channels,
+                act=activation,
+                bias=bias,
+                norm=norm,
+                dropout=dropout,
+                padding=padding,
+                strides=stride,
+            )
+            self.add_module("%d" % l_, layer)
+            input_channels = output_channels
+            output_channels = output_channels * 2
+
+        # Final layer
+        self.add_module(
+            "final_conv",
+            Convolution(
+                spatial_dims=spatial_dims,
+                kernel_size=last_conv_kernel_size,
+                in_channels=input_channels,
+                out_channels=out_channels,
+                bias=True,
+                conv_only=True,
+                padding=int((last_conv_kernel_size - 1) / 2),
+                dropout=0.0,
+                strides=1,
+            ),
+        )
+
+        self.apply(self.initialise_weights)
+
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+        """
+        Args:
+            x: input tensor
+
+        Returns:
+            list of intermediate features, with the last element being the final discriminator output;
+            the intermediate features can be used to compute a feature-matching (regulariser) loss on the
+            discriminator as well (see the Pix2PixHD paper).
+        """
+        out = [x]
+        for submodel in self.children():
+            intermediate_output = submodel(out[-1])
+            out.append(intermediate_output)
+
+        return out[1:]
+
+    def initialise_weights(self, m: nn.Module) -> None:
+        """
+        Initialise weights of Convolution and BatchNorm layers.
+
+        Args:
+            m: instance of torch.nn.Module (or of a class inheriting torch.nn.Module)
+        """
+        classname = m.__class__.__name__
+        if classname.find("Conv2d") != -1:
+            nn.init.normal_(m.weight.data, 0.0, 0.02)
+        elif classname.find("Conv3d") != -1:
+            nn.init.normal_(m.weight.data, 0.0, 0.02)
+        elif classname.find("Conv1d") != -1:
+            nn.init.normal_(m.weight.data, 0.0, 0.02)
+        elif classname.find("BatchNorm") != -1:
+            nn.init.normal_(m.weight.data, 1.0, 0.02)
+            nn.init.constant_(m.bias.data, 0)
diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py
new file mode 100644
index 0000000000..7c3826ae23
--- /dev/null
+++ b/monai/networks/nets/spade_autoencoderkl.py
@@ -0,0 +1,484 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib.util +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.autoencoderkl import Encoder, _AttentionBlock, _Upsample +from monai.utils import ensure_tuple_rep + +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + + has_xformers = True +else: + xformers = None + has_xformers = False + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["SPADEAutoencoderKL"] + + +class SPADEResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm_num_groups: int, + norm_eps: float, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h, seg) + h = F.silu(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class SPADEDecoder(nn.Module): + """ + Convolutional cascade 
upsampling from a spatial latent space into an image space.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+    Args:
+        spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+        num_channels: sequence of block output channels.
+        in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
+        out_channels: number of output channels.
+        num_res_blocks: number of residual blocks (see ResBlock) per level.
+        norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+        norm_eps: epsilon for the normalization.
+        attention_levels: indicate which levels from num_channels contain an attention block.
+        label_nc: number of semantic channels for SPADE normalisation.
+        with_nonlocal_attn: if True use non-local attention block.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        num_channels: Sequence[int],
+        in_channels: int,
+        out_channels: int,
+        num_res_blocks: Sequence[int],
+        norm_num_groups: int,
+        norm_eps: float,
+        attention_levels: Sequence[bool],
+        label_nc: int,
+        with_nonlocal_attn: bool = True,
+        use_flash_attention: bool = False,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.num_channels = num_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.norm_num_groups = norm_num_groups
+        self.norm_eps = norm_eps
+        self.attention_levels = attention_levels
+        self.label_nc = label_nc
+
+        reversed_block_out_channels = list(reversed(num_channels))
+
+        blocks = []
+        # Initial convolution
+        blocks.append(
+            Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                out_channels=reversed_block_out_channels[0],
+                strides=1,
+                kernel_size=3,
+                padding=1,
+                conv_only=True,
+            )
+        )
+
+        # Non-local attention block
+        if with_nonlocal_attn is True:
+            blocks.append(
+                SPADEResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                    label_nc=label_nc,
+                    spade_intermediate_channels=spade_intermediate_channels,
+                )
+            )
+            blocks.append(
+                _AttentionBlock(
+                    spatial_dims=spatial_dims,
+                    num_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    use_flash_attention=use_flash_attention,
+                )
+            )
+            blocks.append(
+                SPADEResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                    label_nc=label_nc,
+                    spade_intermediate_channels=spade_intermediate_channels,
+                )
+            )
+
+        reversed_attention_levels = list(reversed(attention_levels))
+        reversed_num_res_blocks = list(reversed(num_res_blocks))
+        block_out_ch = reversed_block_out_channels[0]
+        for i in range(len(reversed_block_out_channels)):
+            block_in_ch = block_out_ch
+            block_out_ch = reversed_block_out_channels[i]
+            is_final_block = i == len(num_channels) - 1
+
+            for _ in range(reversed_num_res_blocks[i]):
+                blocks.append(
+                    SPADEResBlock(
+                        spatial_dims=spatial_dims,
+                        in_channels=block_in_ch,
+                        norm_num_groups=norm_num_groups,
+                        norm_eps=norm_eps,
+                        out_channels=block_out_ch,
+                        label_nc=label_nc,
spade_intermediate_channels=spade_intermediate_channels, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=False)) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + if isinstance(block, SPADEResBlock): + x = block(x, seg) + else: + x = block(x) + return x + + +class SPADEAutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + label_nc: number of semantic channels for SPADE normalisation. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + num_channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. 
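+
+    Example:
+        A minimal usage sketch (editor's illustration; shapes and hyper-parameters are arbitrary
+        assumptions kept small so the example is cheap to run)::
+
+            import torch
+
+            net = SPADEAutoencoderKL(
+                spatial_dims=2,
+                label_nc=3,
+                in_channels=1,
+                out_channels=1,
+                num_channels=(8, 8, 8),
+                attention_levels=(False, False, False),
+                num_res_blocks=1,
+                norm_num_groups=8,
+            )
+            x = torch.rand(1, 1, 32, 32)
+            seg = torch.rand(1, 3, 32, 32)  # semantic map, interpolated internally by the SPADE blocks
+            reconstruction, z_mu, z_sigma = net(x, seg)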
+ """ + + def __init__( + self, + spatial_dims: int, + label_nc: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("SPADEAutoencoderKL expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("SPADEAutoencoderKL expects num_channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_channels=num_channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + num_channels=num_channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + label_nc=label_nc, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. 
+
+        Args:
+            x: BxCx[SPATIAL DIMS] tensor
+
+        """
+        h = self.encoder(x)
+        z_mu = self.quant_conv_mu(h)
+        z_log_var = self.quant_conv_log_sigma(h)
+        z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
+        z_sigma = torch.exp(z_log_var / 2)
+
+        return z_mu, z_sigma
+
+    def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
+        """
+        From the mean and standard deviation representations resulting from encoding an image through
+        the latent space, obtains a latent sample via the reparameterisation trick: gaussian noise is
+        scaled by the standard deviation (sigma) and shifted by the mean.
+
+        Args:
+            z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
+            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] standard deviation vector obtained by the encoder when you
+                encode an image
+
+        Returns:
+            sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
+        """
+        # reparameterisation trick: z = mu + sigma * eps, with eps ~ N(0, I)
+        eps = torch.randn_like(z_sigma)
+        z_vae = z_mu + eps * z_sigma
+        return z_vae
+
+    def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+        """
+        Encodes and decodes an input image.
+
+        Args:
+            x: BxCx[SPATIAL DIMENSIONS] tensor.
+            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+
+        Returns:
+            reconstructed image, of the same shape as the input
+        """
+        z_mu, _ = self.encode(x)
+        reconstruction = self.decode(z_mu, seg)
+        return reconstruction
+
+    def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+        """
+        Forwards a latent space sample through the decoder.
+
+        Args:
+            z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]
+            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+
+        Returns:
+            decoded image tensor
+        """
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z, seg)
+        return dec
+
+    def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        z_mu, z_sigma = self.encode(x)
+        z = self.sampling(z_mu, z_sigma)
+        reconstruction = self.decode(z, seg)
+        return reconstruction, z_mu, z_sigma
+
+    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+        z_mu, z_sigma = self.encode(x)
+        z = self.sampling(z_mu, z_sigma)
+        return z
+
+    def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+        image = self.decode(z, seg)
+        return image
diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py
new file mode 100644
index 0000000000..819ac015a1
--- /dev/null
+++ b/monai/networks/nets/spade_diffusion_model_unet.py
@@ -0,0 +1,912 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import importlib.util +from collections.abc import Sequence + +import torch +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.diffusion_model_unet import ( + _AttentionBlock, + _Downsample, + _ResnetBlock, + _SpatialTransformer, + _Upsample, + get_down_block, + get_mid_block, + get_timestep_embedding, + zero_module, +) +from monai.utils import ensure_tuple_rep + +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + + has_xformers = True +else: + xformers = None + has_xformers = False + + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["SPADEDiffusionModelUNet"] + + +class SPADEResnetBlock(nn.Module): + """ + Residual block with timestep conditioning and SPADE norm. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. 
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + label_nc: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=self.out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h, seg) + h = self.nonlinearity(h) + h = self.conv2(h) + + return self.skip_connection(x) + h + + +class SPADEUpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. 
+        add_upsample: if True add upsample block.
+        resblock_updown: if True use residual blocks for upsampling.
+        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        label_nc: int,
+        num_res_blocks: int = 1,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        add_upsample: bool = True,
+        resblock_updown: bool = False,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        self.resblock_updown = resblock_updown
+        resnets = []
+
+        for i in range(num_res_blocks):
+            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                SPADEResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    label_nc=label_nc,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    spade_intermediate_channels=spade_intermediate_channels,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            if resblock_updown:
+                self.upsampler = _ResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    up=True,
+                )
+            else:
+                self.upsampler = _Upsample(
+                    spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
+                )
+        else:
+            self.upsampler = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        res_hidden_states_list: list[torch.Tensor],
+        temb: torch.Tensor,
+        seg: torch.Tensor,
+        context: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        del context
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_list[-1]
+            res_hidden_states_list = res_hidden_states_list[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb, seg)
+
+        if self.upsampler is not None:
+            hidden_states = self.upsampler(hidden_states, temb)
+
+        return hidden_states
+
+
+class SPADEAttnUpBlock(nn.Module):
+    """
+    Unet's up block containing resnet, upsamplers, and self-attention blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+    Args:
+        spatial_dims: The number of spatial dimensions.
+        in_channels: number of input channels.
+        prev_output_channel: number of channels from residual connection.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        label_nc: number of semantic channels for SPADE normalisation
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+        resblock_updown: if True use residual blocks for upsampling.
+        num_head_channels: number of channels in each attention head.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        spade_intermediate_channels: number of intermediate channels for SPADE block layer
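+
+    Example:
+        A minimal usage sketch (editor's illustration; channel counts, shapes and the semantic map
+        below are arbitrary assumptions)::
+
+            import torch
+
+            block = SPADEAttnUpBlock(
+                spatial_dims=2, in_channels=8, prev_output_channel=8, out_channels=8,
+                temb_channels=16, label_nc=3, norm_num_groups=8,
+            )
+            h = torch.rand(1, 8, 16, 16)
+            res_list = [torch.rand(1, 8, 16, 16)]  # skip connections, popped from the end
+            temb = torch.rand(1, 16)
+            seg = torch.rand(1, 3, 16, 16)
+            out = block(h, res_list, temb, seg)  # upsampled to (1, 8, 32, 32)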
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        label_nc: int,
+        num_res_blocks: int = 1,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        add_upsample: bool = True,
+        resblock_updown: bool = False,
+        num_head_channels: int = 1,
+        use_flash_attention: bool = False,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        self.resblock_updown = resblock_updown
+        resnets = []
+        attentions = []
+
+        for i in range(num_res_blocks):
+            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                SPADEResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    label_nc=label_nc,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    spade_intermediate_channels=spade_intermediate_channels,
+                )
+            )
+            attentions.append(
+                _AttentionBlock(
+                    spatial_dims=spatial_dims,
+                    num_channels=out_channels,
+                    num_head_channels=num_head_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    use_flash_attention=use_flash_attention,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.attentions = nn.ModuleList(attentions)
+
+        if add_upsample:
+            if resblock_updown:
+                self.upsampler = _ResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    up=True,
+                )
+            else:
+                self.upsampler = _Upsample(
+                    spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
+                )
+        else:
+            self.upsampler = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        res_hidden_states_list: list[torch.Tensor],
+        temb: torch.Tensor,
+        seg: torch.Tensor,
+        context: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        del context
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_list[-1]
+            res_hidden_states_list = res_hidden_states_list[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb, seg)
+            hidden_states = attn(hidden_states)
+
+        if self.upsampler is not None:
+            hidden_states = self.upsampler(hidden_states, temb)
+
+        return hidden_states
+
+
+class SPADECrossAttnUpBlock(nn.Module):
+    """
+    Unet's up block containing resnet, upsamplers, and cross-attention blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+    Args:
+        spatial_dims: The number of spatial dimensions.
+        in_channels: number of input channels.
+        prev_output_channel: number of channels from residual connection.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+        resblock_updown: if True use residual blocks for upsampling.
+        num_head_channels: number of channels in each attention head.
+        transformer_num_layers: number of layers of Transformer blocks to use.
+        cross_attention_dim: number of context dimensions to use.
+ upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + label_nc: number of semantic channels for SPADE normalisation. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + label_nc: int | None = None, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor | None = None, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_spade_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + label_nc: int, + cross_attention_dim: int | None, + upcast_attention: 
bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, +) -> nn.Module: + if with_attn: + return SPADEAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + elif with_cross_attn: + return SPADECrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + else: + return SPADEUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + spade_intermediate_channels=spade_intermediate_channels, + ) + + +class SPADEDiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
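+        spade_intermediate_channels: number of intermediate channels for SPADE block layer
+
+    Example:
+        A minimal usage sketch (editor's illustration; channel counts and shapes are arbitrary
+        assumptions kept small so the example is cheap to run)::
+
+            import torch
+
+            net = SPADEDiffusionModelUNet(
+                spatial_dims=2,
+                in_channels=1,
+                out_channels=1,
+                label_nc=3,
+                num_channels=(8, 8, 8),
+                attention_levels=(False, False, False),
+                num_res_blocks=1,
+                norm_num_groups=8,
+            )
+            x = torch.rand(1, 1, 32, 32)
+            timesteps = torch.randint(0, 1000, (1,)).long()
+            seg = torch.rand(1, 3, 32, 32)
+            prediction = net(x, timesteps, seg)  # same shape as x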
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        label_nc: int,
+        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+        num_channels: Sequence[int] = (32, 64, 64, 64),
+        attention_levels: Sequence[bool] = (False, False, True, True),
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-6,
+        resblock_updown: bool = False,
+        num_head_channels: int | Sequence[int] = 8,
+        with_conditioning: bool = False,
+        transformer_num_layers: int = 1,
+        cross_attention_dim: int | None = None,
+        num_class_embeds: int | None = None,
+        upcast_attention: bool = False,
+        use_flash_attention: bool = False,
+        spade_intermediate_channels: int = 128,
+    ) -> None:
+        super().__init__()
+        if with_conditioning is True and cross_attention_dim is None:
+            raise ValueError(
+                "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "when using with_conditioning."
+            )
+        if cross_attention_dim is not None and with_conditioning is False:
+            raise ValueError(
+                "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
+            )
+
+        # All numbers of channels should be multiples of num_groups
+        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+            raise ValueError("SPADEDiffusionModelUNet expects all num_channels to be multiples of norm_num_groups")
+
+        if len(num_channels) != len(attention_levels):
+            raise ValueError("SPADEDiffusionModelUNet expects num_channels to have the same length as attention_levels")
+
+        if isinstance(num_head_channels, int):
+            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+        if len(num_head_channels) != len(attention_levels):
+            raise ValueError(
+                "num_head_channels should have the same length as attention_levels. For the levels without attention,"
+                " i.e. `attention_levels[i]=False`, num_head_channels[i] will be ignored."
+            )
+
+        if isinstance(num_res_blocks, int):
+            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))
+
+        if len(num_res_blocks) != len(num_channels):
+            raise ValueError(
+                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                "`num_channels`."
+            )
+
+        if use_flash_attention and not has_xformers:
+            raise ValueError("use_flash_attention is True but xformers is not installed.")
+
+        if use_flash_attention is True and not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
+ ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + self.label_nc = label_nc + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_spade_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + 
upcast_attention=upcast_attention,
+                use_flash_attention=use_flash_attention,
+                label_nc=label_nc,
+                spade_intermediate_channels=spade_intermediate_channels,
+            )
+
+            self.up_blocks.append(up_block)
+
+        # out
+        self.out = nn.Sequential(
+            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),
+            nn.SiLU(),
+            zero_module(
+                Convolution(
+                    spatial_dims=spatial_dims,
+                    in_channels=num_channels[0],
+                    out_channels=out_channels,
+                    strides=1,
+                    kernel_size=3,
+                    padding=1,
+                    conv_only=True,
+                )
+            ),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        seg: torch.Tensor,
+        context: torch.Tensor | None = None,
+        class_labels: torch.Tensor | None = None,
+        down_block_additional_residuals: tuple[torch.Tensor, ...] | None = None,
+        mid_block_additional_residual: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: input tensor (N, C, SpatialDims).
+            timesteps: timestep tensor (N,).
+            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+                The map will be interpolated to the feature map dimensions internally.
+            context: context tensor (N, 1, ContextDim).
+            class_labels: class labels tensor (N,).
+            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
+            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
+        """
+        # 1. time
+        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+        # `get_timestep_embedding` does not contain any weights and always returns f32 tensors,
+        # but time_embed might actually be running in fp16, so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=x.dtype)
+        emb = self.time_embed(t_emb)
+
+        # 2. class
+        if self.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels)
+            class_emb = class_emb.to(dtype=x.dtype)
+            emb = emb + class_emb
+
+        # 3. initial convolution
+        h = self.conv_in(x)
+
+        # 4. down
+        if context is not None and self.with_conditioning is False:
+            raise ValueError("model should have with_conditioning = True if context is provided")
+        down_block_res_samples: list[torch.Tensor] = [h]
+        for downsample_block in self.down_blocks:
+            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+            for residual in res_samples:
+                down_block_res_samples.append(residual)
+
+        # additional residual connections for ControlNets
+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples: list[torch.Tensor] = []
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples.append(down_block_res_sample)
+
+            down_block_res_samples = new_down_block_res_samples
+
+        # 5. mid
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+        # additional residual connections for ControlNets
+        if mid_block_additional_residual is not None:
+            h = h + mid_block_additional_residual
+
+        # 6. up
+        for upsample_block in self.up_blocks:
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context)
+
+        # 7.
output block + h = self.out(h) + + return h diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py new file mode 100644 index 0000000000..3d976d6950 --- /dev/null +++ b/monai/networks/nets/transformer.py @@ -0,0 +1,320 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib.util +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.mlp import MLPBlock + +if importlib.util.find_spec("xformers") is not None: + import xformers.ops as xops + + has_xformers = True +else: + has_xformers = False +__all__ = ["DecoderOnlyTransformer"] + + +class _SABlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + dropout_rate: dropout ratio. Defaults to no dropout. + qkv_bias: bias term for the qkv linear layer. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + self.causal = causal + self.sequence_length = sequence_length + self.with_cross_attention = with_cross_attention + self.use_flash_attention = use_flash_attention + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + self.dropout_rate = dropout_rate + + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + # key, query, value projections + self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + # regularization + self.drop_weights = nn.Dropout(dropout_rate) + self.drop_output = nn.Dropout(dropout_rate) + + # output projection + self.out_proj = nn.Linear(hidden_size, hidden_size) + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query = self.to_q(x) + + kv = context if context is not None else x + _, kv_t, _ = kv.size() + key = self.to_k(kv) + value = self.to_v(kv) + + query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) + key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + + if self.use_flash_attention: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + y = xops.memory_efficient_attention( + query=query, + key=key, + value=value, + scale=self.scale, + p=self.dropout_rate, + attn_bias=xops.LowerTriangularMask() if self.causal else None, + ) + + else: + query = query.transpose(1, 2) # (b, nh, t, hs) + key = key.transpose(1, 2) # (b, nh, kv_t, hs) + value = value.transpose(1, 2) # (b, nh, kv_t, hs) + + # manual implementation of attention + query = query * self.scale + attention_scores = query @ key.transpose(-2, -1) + + if self.causal: + attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.drop_weights(attention_probs) + y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) + + y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) + + y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side + + y = 
self.out_proj(y)
+        y = self.drop_output(y)
+        return y
+
+
+class _TransformerBlock(nn.Module):
+    """
+    NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
+    use of this block as support is not guaranteed. For more information see:
+    https://github.com/Project-MONAI/MONAI/issues/7227
+
+    A transformer block, based on: "Dosovitskiy et al.,
+    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
+
+    Args:
+        hidden_size: dimension of hidden layer.
+        mlp_dim: dimension of feedforward layer.
+        num_heads: number of attention heads.
+        dropout_rate: fraction of the input units to drop.
+        qkv_bias: apply bias term for the qkv linear layer.
+        causal: whether to use causal attention.
+        sequence_length: if causal is True, it is necessary to specify the sequence length.
+        with_cross_attention: whether to use cross attention for conditioning.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        mlp_dim: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        causal: bool = False,
+        sequence_length: int | None = None,
+        with_cross_attention: bool = False,
+        use_flash_attention: bool = False,
+    ) -> None:
+        super().__init__()
+        self.with_cross_attention = with_cross_attention
+
+        if not (0 <= dropout_rate <= 1):
+            raise ValueError("dropout_rate should be between 0 and 1.")
+
+        if hidden_size % num_heads != 0:
+            raise ValueError("hidden_size should be divisible by num_heads.")
+
+        self.norm1 = nn.LayerNorm(hidden_size)
+        self.attn = _SABlock(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            dropout_rate=dropout_rate,
+            qkv_bias=qkv_bias,
+            causal=causal,
+            sequence_length=sequence_length,
+            use_flash_attention=use_flash_attention,
+        )
+
+        self.norm2 = None
+        self.cross_attn = None
+        if self.with_cross_attention:
+            self.norm2 = nn.LayerNorm(hidden_size)
+            self.cross_attn = _SABlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                with_cross_attention=with_cross_attention,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
+
+        self.norm3 = nn.LayerNorm(hidden_size)
+        self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x))
+        if self.with_cross_attention:
+            x = x + self.cross_attn(self.norm2(x), context=context)
+        x = x + self.mlp(self.norm3(x))
+        return x
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+    """Absolute positional embedding.
+
+    Args:
+        max_seq_len: Maximum sequence length.
+        embedding_dim: Dimensionality of the embedding.
+    """
+
+    def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
+        super().__init__()
+        self.max_seq_len = max_seq_len
+        self.embedding_dim = embedding_dim
+        self.embedding = nn.Embedding(max_seq_len, embedding_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_len = x.size()
+        positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
+        return self.embedding(positions)
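+
+
+# A minimal sketch of AbsolutePositionalEmbedding's behaviour (illustrative values only):
+# the output depends on token positions, never on token values, so it is identical
+# across the batch.
+#
+#   pos_emb = AbsolutePositionalEmbedding(max_seq_len=16, embedding_dim=8)
+#   tokens = torch.randint(0, 10, (2, 16))  # (batch, seq_len) integer token ids
+#   out = pos_emb(tokens)                   # (2, 16, 8), equal for both batch rows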
+
+
+class DecoderOnlyTransformer(nn.Module):
+    """Decoder-only (Autoregressive) Transformer model.
+
+    Args:
+        num_tokens: Number of tokens in the vocabulary.
+        max_seq_len: Maximum sequence length.
+        attn_layers_dim: Dimensionality of the attention layers.
+        attn_layers_depth: Number of attention layers.
+        attn_layers_heads: Number of attention heads.
+        with_cross_attention: Whether to use cross attention for conditioning.
+        embedding_dropout_rate: Dropout rate for the embedding.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+    """
+
+    def __init__(
+        self,
+        num_tokens: int,
+        max_seq_len: int,
+        attn_layers_dim: int,
+        attn_layers_depth: int,
+        attn_layers_heads: int,
+        with_cross_attention: bool = False,
+        embedding_dropout_rate: float = 0.0,
+        use_flash_attention: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_tokens = num_tokens
+        self.max_seq_len = max_seq_len
+        self.attn_layers_dim = attn_layers_dim
+        self.attn_layers_depth = attn_layers_depth
+        self.attn_layers_heads = attn_layers_heads
+        self.with_cross_attention = with_cross_attention
+
+        self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
+        self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
+        self.embedding_dropout = nn.Dropout(embedding_dropout_rate)
+
+        self.blocks = nn.ModuleList(
+            [
+                _TransformerBlock(
+                    hidden_size=attn_layers_dim,
+                    mlp_dim=attn_layers_dim * 4,
+                    num_heads=attn_layers_heads,
+                    dropout_rate=0.0,
+                    qkv_bias=False,
+                    causal=True,
+                    sequence_length=max_seq_len,
+                    with_cross_attention=with_cross_attention,
+                    use_flash_attention=use_flash_attention,
+                )
+                for _ in range(attn_layers_depth)
+            ]
+        )
+
+        self.to_logits = nn.Linear(attn_layers_dim, num_tokens)
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+        tok_emb = self.token_embeddings(x)
+        pos_emb = self.position_embeddings(x)
+        x = self.embedding_dropout(tok_emb + pos_emb)
+
+        for block in self.blocks:
+            x = block(x, context=context)
+
+        return self.to_logits(x)
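+
+
+# A minimal usage sketch with illustrative hyperparameters: the model maps integer
+# token ids to per-position logits over the vocabulary, so a next-token objective can
+# shift logits and targets by one position.
+#
+#   model = DecoderOnlyTransformer(
+#       num_tokens=100, max_seq_len=16, attn_layers_dim=32, attn_layers_depth=2, attn_layers_heads=4
+#   )
+#   tokens = torch.randint(0, 100, (1, 16))
+#   logits = model(tokens)  # (1, 16, 100)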
diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py
new file mode 100644
index 0000000000..0a37e06ea4
--- /dev/null
+++ b/monai/networks/nets/vqvae.py
@@ -0,0 +1,456 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import Convolution
+from monai.networks.layers import Act
+from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer
+from monai.utils import ensure_tuple_rep
+
+__all__ = ["VQVAE"]
+
+
+class VQVAEResidualUnit(nn.Module):
+    """
+    Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving
+    Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf).
+
+    The original implementation can be found at
+    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150.
+
+    Args:
+        spatial_dims: number of spatial dimensions of the input data.
+        num_channels: number of input channels.
+        num_res_channels: number of channels in the residual layers.
+        act: activation type and arguments. Defaults to RELU.
+        dropout: dropout ratio. Defaults to no dropout.
+        bias: whether to have a bias term. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        num_channels: int,
+        num_res_channels: int,
+        act: tuple | str | None = Act.RELU,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        self.spatial_dims = spatial_dims
+        self.num_channels = num_channels
+        self.num_res_channels = num_res_channels
+        self.act = act
+        self.dropout = dropout
+        self.bias = bias
+
+        self.conv1 = Convolution(
+            spatial_dims=self.spatial_dims,
+            in_channels=self.num_channels,
+            out_channels=self.num_res_channels,
+            adn_ordering="DA",
+            act=self.act,
+            dropout=self.dropout,
+            bias=self.bias,
+        )
+
+        self.conv2 = Convolution(
+            spatial_dims=self.spatial_dims,
+            in_channels=self.num_res_channels,
+            out_channels=self.num_channels,
+            bias=self.bias,
+            conv_only=True,
+        )
+
+    def forward(self, x):
+        return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True)
+
+
+class Encoder(nn.Module):
+    """
+    Encoder module for VQ-VAE.
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        out_channels: number of channels in the latent space (embedding_dim).
+        num_channels: number of channels at each level.
+        num_res_layers: number of sequential residual layers at each level.
+        num_res_channels: number of channels in the residual layers at each level.
+        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int) and padding (int).
+        dropout: dropout ratio.
+        act: activation type and arguments.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        num_channels: Sequence[int],
+        num_res_layers: int,
+        num_res_channels: Sequence[int],
+        downsample_parameters: Sequence[tuple[int, int, int, int]],
+        dropout: float,
+        act: tuple | str | None,
+    ) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_channels = num_channels
+        self.num_res_layers = num_res_layers
+        self.num_res_channels = num_res_channels
+        self.downsample_parameters = downsample_parameters
+        self.dropout = dropout
+        self.act = act
+
+        blocks = []
+
+        for i in range(len(self.num_channels)):
+            blocks.append(
+                Convolution(
+                    spatial_dims=self.spatial_dims,
+                    in_channels=self.in_channels if i == 0 else self.num_channels[i - 1],
+                    out_channels=self.num_channels[i],
+                    strides=self.downsample_parameters[i][0],
+                    kernel_size=self.downsample_parameters[i][1],
+                    adn_ordering="DA",
+                    act=self.act,
+                    dropout=None if i == 0 else self.dropout,
+                    dropout_dim=1,
+                    dilation=self.downsample_parameters[i][2],
+                    padding=self.downsample_parameters[i][3],
+                )
+            )
+
+            for _ in range(self.num_res_layers):
+                blocks.append(
+                    VQVAEResidualUnit(
+                        spatial_dims=self.spatial_dims,
+                        num_channels=self.num_channels[i],
+                        num_res_channels=self.num_res_channels[i],
+                        act=self.act,
+                        dropout=self.dropout,
+                    )
+                )
+
+        blocks.append(
+            Convolution(
+                spatial_dims=self.spatial_dims,
+                in_channels=self.num_channels[len(self.num_channels) - 1],
+                out_channels=self.out_channels,
+                strides=1,
+                kernel_size=3,
+                padding=1,
+                conv_only=True,
+            )
+        )
+
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for block in self.blocks:
+            x = block(x)
+        return x
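+
+
+# A worked example of the downsample_parameters convention (stride, kernel_size,
+# dilation, padding), with illustrative channel settings: each default (2, 4, 1, 1)
+# level maps a spatial size N to floor((N + 2 - 3 - 1) / 2) + 1 = N // 2.
+#
+#   encoder = Encoder(
+#       spatial_dims=2, in_channels=1, out_channels=4, num_channels=(8, 16),
+#       num_res_layers=1, num_res_channels=(8, 16),
+#       downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)), dropout=0.0, act="RELU",
+#   )
+#   z = encoder(torch.randn(1, 1, 32, 32))  # (1, 4, 8, 8)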
+
+
+class Decoder(nn.Module):
+    """
+    Decoder module for VQ-VAE.
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of channels in the latent space (embedding_dim).
+        out_channels: number of output channels.
+        num_channels: number of channels at each level.
+        num_res_layers: number of sequential residual layers at each level.
+        num_res_channels: number of channels in the residual layers at each level.
+        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
+        dropout: dropout ratio.
+        act: activation type and arguments.
+        output_act: activation type and arguments for the output.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        num_channels: Sequence[int],
+        num_res_layers: int,
+        num_res_channels: Sequence[int],
+        upsample_parameters: Sequence[tuple[int, int, int, int, int]],
+        dropout: float,
+        act: tuple | str | None,
+        output_act: tuple | str | None,
+    ) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_channels = num_channels
+        self.num_res_layers = num_res_layers
+        self.num_res_channels = num_res_channels
+        self.upsample_parameters = upsample_parameters
+        self.dropout = dropout
+        self.act = act
+        self.output_act = output_act
+
+        reversed_num_channels = list(reversed(self.num_channels))
+
+        blocks = []
+        blocks.append(
+            Convolution(
+                spatial_dims=self.spatial_dims,
+                in_channels=self.in_channels,
+                out_channels=reversed_num_channels[0],
+                strides=1,
+                kernel_size=3,
+                padding=1,
+                conv_only=True,
+            )
+        )
+
+        reversed_num_res_channels = list(reversed(self.num_res_channels))
+        for i in range(len(self.num_channels)):
+            for _ in range(self.num_res_layers):
+                blocks.append(
+                    VQVAEResidualUnit(
+                        spatial_dims=self.spatial_dims,
+                        num_channels=reversed_num_channels[i],
+                        num_res_channels=reversed_num_res_channels[i],
+                        act=self.act,
+                        dropout=self.dropout,
+                    )
+                )
+
+            blocks.append(
+                Convolution(
+                    spatial_dims=self.spatial_dims,
+                    in_channels=reversed_num_channels[i],
+                    out_channels=self.out_channels if i == len(self.num_channels) - 1 else reversed_num_channels[i + 1],
+                    strides=self.upsample_parameters[i][0],
+                    kernel_size=self.upsample_parameters[i][1],
+                    adn_ordering="DA",
+                    act=self.act,
+                    dropout=self.dropout if i != len(self.num_channels) - 1 else None,
+                    norm=None,
+                    dilation=self.upsample_parameters[i][2],
+                    conv_only=i == len(self.num_channels) - 1,
+                    is_transposed=True,
+                    padding=self.upsample_parameters[i][3],
+                    output_padding=self.upsample_parameters[i][4],
+                )
+            )
+
+        if self.output_act:
+            blocks.append(Act[self.output_act]())
+
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for block in self.blocks:
+            x = block(x)
+        return x
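+
+
+# The upsample_parameters convention mirrors this with an extra output_padding entry:
+# (stride, kernel_size, dilation, padding, output_padding). For a transposed convolution
+#   out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
+# so the default (2, 4, 1, 1, 0) exactly doubles each spatial dimension:
+# (N - 1) * 2 - 2 + 3 + 0 + 1 = 2 * N.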
+
+
+class VQVAE(nn.Module):
+    """
+    Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative
+    Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf)
+
+    The original implementation can be found at
+    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        out_channels: number of output channels.
+        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int) and padding (int).
+        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
+        num_res_layers: number of sequential residual layers at each level.
+        num_channels: number of channels at each level.
+        num_res_channels: number of channels in the residual layers at each level.
+        num_embeddings: VectorQuantization number of atomic elements in the codebook.
+        embedding_dim: VectorQuantization number of channels of the input and atomic elements.
+        embedding_init: initialization method for the codebook.
+        commitment_cost: VectorQuantization commitment_cost.
+        decay: VectorQuantization decay.
+        epsilon: VectorQuantization epsilon.
+        act: activation type and arguments.
+        dropout: dropout ratio.
+        output_act: activation type and arguments for the output.
+        ddp_sync: whether to synchronize the codebook across processes.
+        use_checkpointing: if True, use activation checkpointing to save memory.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        num_channels: Sequence[int] | int = (96, 96, 192),
+        num_res_layers: int = 3,
+        num_res_channels: Sequence[int] | int = (96, 96, 192),
+        downsample_parameters: Sequence[tuple[int, int, int, int]]
+        | tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)),
+        upsample_parameters: Sequence[tuple[int, int, int, int, int]]
+        | tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+        num_embeddings: int = 32,
+        embedding_dim: int = 64,
+        embedding_init: str = "normal",
+        commitment_cost: float = 0.25,
+        decay: float = 0.5,
+        epsilon: float = 1e-5,
+        dropout: float = 0.0,
+        act: tuple | str | None = Act.RELU,
+        output_act: tuple | str | None = None,
+        ddp_sync: bool = True,
+        use_checkpointing: bool = False,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.spatial_dims = spatial_dims
+        self.num_channels = num_channels
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.use_checkpointing = use_checkpointing
+
+        if isinstance(num_res_channels, int):
+            num_res_channels = ensure_tuple_rep(num_res_channels, len(num_channels))
+
+        if len(num_res_channels) != len(num_channels):
+            raise ValueError(
+                "`num_res_channels` should be a single integer or a tuple of integers with the same length as "
+                "`num_channels`."
+            )
+
+        if not all(isinstance(values, (int, Sequence)) for values in downsample_parameters):
+            raise ValueError("`downsample_parameters` should be a single tuple of integers or a tuple of tuples.")
+
+        if not all(isinstance(values, (int, Sequence)) for values in upsample_parameters):
+            raise ValueError("`upsample_parameters` should be a single tuple of integers or a tuple of tuples.")
+
+        if all(isinstance(values, int) for values in upsample_parameters):
+            upsample_parameters = (upsample_parameters,) * len(num_channels)
+
+        if all(isinstance(values, int) for values in downsample_parameters):
+            downsample_parameters = (downsample_parameters,) * len(num_channels)
+
+        for parameter in downsample_parameters:
+            if len(parameter) != 4:
+                raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.")
+
+        for parameter in upsample_parameters:
+            if len(parameter) != 5:
+                raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.")
+
+        if len(downsample_parameters) != len(num_channels):
+            raise ValueError(
+                "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`."
+            )
+
+        if len(upsample_parameters) != len(num_channels):
+            raise ValueError(
+                "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`."
+            )
+
+        self.num_res_layers = num_res_layers
+        self.num_res_channels = num_res_channels
+
+        self.encoder = Encoder(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=embedding_dim,
+            num_channels=num_channels,
+            num_res_layers=num_res_layers,
+            num_res_channels=num_res_channels,
+            downsample_parameters=downsample_parameters,
+            dropout=dropout,
+            act=act,
+        )
+
+        self.decoder = Decoder(
+            spatial_dims=spatial_dims,
+            in_channels=embedding_dim,
+            out_channels=out_channels,
+            num_channels=num_channels,
+            num_res_layers=num_res_layers,
+            num_res_channels=num_res_channels,
+            upsample_parameters=upsample_parameters,
+            dropout=dropout,
+            act=act,
+            output_act=output_act,
+        )
+
+        self.quantizer = VectorQuantizer(
+            quantizer=EMAQuantizer(
+                spatial_dims=spatial_dims,
+                num_embeddings=num_embeddings,
+                embedding_dim=embedding_dim,
+                commitment_cost=commitment_cost,
+                decay=decay,
+                epsilon=epsilon,
+                embedding_init=embedding_init,
+                ddp_sync=ddp_sync,
+            )
+        )
+
+    def encode(self, images: torch.Tensor) -> torch.Tensor:
+        if self.use_checkpointing:
+            return torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False)
+        else:
+            return self.encoder(images)
+
+    def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        x_loss, x = self.quantizer(encodings)
+        return x, x_loss
+
+    def decode(self, quantizations: torch.Tensor) -> torch.Tensor:
+        if self.use_checkpointing:
+            return torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False)
+        else:
+            return self.decoder(quantizations)
+
+    def index_quantize(self, images: torch.Tensor) -> torch.Tensor:
+        return self.quantizer.quantize(self.encode(images=images))
+
+    def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+        return self.decode(self.quantizer.embed(embedding_indices))
+
+    def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        quantizations, quantization_losses = self.quantize(self.encode(images))
+        reconstruction = self.decode(quantizations)
+
+        return reconstruction, quantization_losses
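+
+    # A typical training step combines the reconstruction term with the returned
+    # quantization loss; a sketch assuming an L1 reconstruction objective:
+    #
+    #   reconstruction, quantization_loss = vqvae(images)
+    #   loss = torch.nn.functional.l1_loss(reconstruction, images) + quantization_loss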
+
+    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+        z = self.encode(x)
+        e, _ = self.quantize(z)
+        return e
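+
+    # The stage-2 helpers support two-stage latent generative modelling:
+    # encode_stage_2_inputs maps images to quantized latents (e.g. to train a prior such
+    # as DecoderOnlyTransformer on top of them), and decode_stage_2_outputs below maps
+    # such latents back to image space; a sketch assuming a single-channel 2D input:
+    #
+    #   latents = vqvae.encode_stage_2_inputs(torch.randn(1, 1, 32, 32))
+    #   images = vqvae.decode_stage_2_outputs(latents)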
+ + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + e, _ = self.quantize(z) + image = self.decode(e) + return image diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py new file mode 100644 index 0000000000..58943a1626 --- /dev/null +++ b/tests/test_autoencoderkl.py @@ -0,0 +1,270 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import AutoencoderKL + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + + +class TestAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = 
net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + @parameterized.expand(CASES) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), 
torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py new file mode 100644 index 0000000000..27acccc083 --- /dev/null +++ b/tests/test_controlnet.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.controlnet import ControlNet + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "conditioning_embedding_in_channels": 1, + "conditioning_embedding_num_channels": (8, 8), + }, + 6, + (1, 8, 4, 4), + ] +] + + +class TestControlNet(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): + net = ControlNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 32, 32)) + ) + self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) + self.assertEqual(result[1].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py new file mode 100644 index 0000000000..77ae1afcaa --- /dev/null +++ b/tests/test_diffusion_model_unet.py @@ -0,0 +1,535 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], + [ + { + "spatial_dims": 2, + 
"in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + } + ], +] + +DROPOUT_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], +] + +DROPOUT_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0, + } + ] +] + + +class TestDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def 
test_num_res_blocks_with_different_length_num_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + 
result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + # Test dropout specification for cross-attention blocks + @parameterized.expand(DROPOUT_WRONG) + def test_wrong_dropout(self, input_param): + with self.assertRaises(ValueError): + _ = DiffusionModelUNet(**input_param) + + @parameterized.expand(DROPOUT_OK) + def test_right_dropout(self, input_param): + _ = DiffusionModelUNet(**input_param) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_patch_gan.py b/tests/test_patch_gan.py new file mode 100644 index 0000000000..162fa5e8fc --- /dev/null +++ b/tests/test_patch_gan.py @@ -0,0 +1,116 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.patchgan_discriminator import MultiScalePatchDiscriminator +from tests.utils import test_script_save + +TEST_2D = [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 2, + "num_channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512]), + [(1, 1, 32, 64), (1, 1, 4, 8)], + [4, 7], +] +TEST_3D = [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 3, + "num_channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512, 256]), + [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], + [4, 7], +] +TEST_TOO_SMALL_SIZE = [ + { + "num_d": 2, + "num_layers_d": 6, + "spatial_dims": 2, + "num_channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + } +] + +CASES = [TEST_2D, TEST_3D] + + +class TestPatchGAN(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): + net = MultiScalePatchDiscriminator(**input_param) + with eval_mode(net): + result, features = net.forward(input_data) + for r_ind, r in enumerate(result): + self.assertEqual(tuple(r.shape), expected_shape[r_ind]) + for o_d_ind, o_d in enumerate(features): + self.assertEqual(len(o_d), features_lengths[o_d_ind]) + + def test_too_small_shape(self): + with self.assertRaises(AssertionError): + MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) + + def test_script(self): + net = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + num_channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + minimum_size_im=256, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py new file mode 100644 index 0000000000..6de7d6b67c --- /dev/null +++ b/tests/test_spade_autoencoderkl.py @@ -0,0 +1,260 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEAutoencoderKL + +CASES = [ + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 3, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + "spade_intermediate_channels": 32, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], +] + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class TestSPADEAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape): + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + 
SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_encode(self): + input_param, input_shape, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + def test_wrong_shape_decode(self): + net = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, False), + num_res_blocks=1, + norm_num_groups=4, + ).to(device) + with self.assertRaises(RuntimeError): + _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py new file mode 100644 index 0000000000..c8a2103cf6 --- /dev/null +++ b/tests/test_spade_diffusion_model_unet.py @@ -0,0 +1,558 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
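+
+# Hedged usage sketch (not executed by this suite): the cases below assume that
+# SPADEDiffusionModelUNet.forward takes the noisy input, the diffusion timesteps and a
+# semantic map, and returns a tensor of the same shape as the input, e.g.
+#
+#     net = SPADEDiffusionModelUNet(
+#         spatial_dims=2, label_nc=3, in_channels=1, out_channels=1, num_res_blocks=1,
+#         num_channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8,
+#     )
+#     out = net(torch.rand(1, 1, 16, 16), torch.randint(0, 1000, (1,)).long(), torch.rand(1, 3, 16, 16))
+#     # out.shape == (1, 1, 16, 16)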
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEDiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + "spade_intermediate_channels": 256, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + 
"attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + "label_nc": 3, + } + ], +] + + +class TestSPADEDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) + ) + + def test_label_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(RuntimeError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16)) + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + 
out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_num_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net 
= SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, 3, 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 16, 16)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 0000000000..3dd2d6621a --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
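+
+# Hedged usage sketch (not executed by this suite): DecoderOnlyTransformer consumes a
+# (batch, sequence) tensor of token indices in [0, num_tokens); with
+# with_cross_attention=True it additionally accepts a `context` tensor whose last
+# dimension matches attn_layers_dim, e.g.
+#
+#     net = DecoderOnlyTransformer(
+#         num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
+#     )
+#     logits = net(torch.randint(0, 10, (1, 16)))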
+ +from __future__ import annotations + +import unittest + +import torch + +from monai.networks import eval_mode +from monai.networks.nets import DecoderOnlyTransformer + + +class TestDecoderOnlyTransformer(unittest.TestCase): + def test_unconditioned_models(self): + net = DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2 + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + def test_conditioned_models(self): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + with_cross_attention=True, + embedding_dropout_rate=0, + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py new file mode 100644 index 0000000000..7f7e7a4fd7 --- /dev/null +++ b/tests/test_vector_quantizer.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.layers import EMAQuantizer, VectorQuantizer + + +class TestEMA(unittest.TestCase): + def test_ema_shape(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8) + input_shape = (1, 8, 8, 8) + x = torch.randn(input_shape) + layer = layer.train() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, (1, 8, 8)) + + layer = layer.eval() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, (1, 8, 8)) + + def test_ema_quantize(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8) + input_shape = (1, 8, 8, 8) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs[0].shape, (64, 8)) # (HxW, C) + self.assertEqual(outputs[1].shape, (64, 16)) # (HxW, E) + self.assertEqual(outputs[2].shape, (1, 8, 8)) # (1, H, W) + + def test_ema(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) + original_weight_0 = layer.embedding.weight[0].clone() + original_weight_1 = layer.embedding.weight[1].clone() + x_0 = original_weight_0 + x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 + + x_1 = original_weight_1 + x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_1 = x_1.repeat(1, 1, 1, 2) + + x = torch.cat([x_0, x_1], dim=0) + layer = layer.train() + _ = layer(x) + + self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) + self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) + + +class TestVectorQuantizer(unittest.TestCase): + def test_vector_quantizer_shape(self): + layer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8)) + input_shape = (1, 8, 8, 8) + x = torch.randn(input_shape) + outputs 
= layer(x) + self.assertEqual(outputs[1].shape, input_shape) + + def test_vector_quantizer_quantize(self): + layer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8)) + input_shape = (1, 8, 8, 8) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs.shape, (1, 8, 8)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py new file mode 100644 index 0000000000..c0f6426e23 --- /dev/null +++ b/tests/test_vqvae.py @@ -0,0 +1,272 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vqvae import VQVAE + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], +] + +TEST_LATENT_SHAPE = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_channels": (8, 8), + "num_res_channels": (8, 8), + "num_embeddings": 16, + "embedding_dim": 8, +} + + +class TestVQVAE(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + input_param = 
input_param.copy() + input_param.update({"use_checkpointing": True}) + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + # Removed this test case since TorchScript currently does not support activation checkpointing. + # def test_script(self): + # net = VQVAE( + # spatial_dims=2, + # in_channels=1, + # out_channels=1, + # downsample_parameters=((2, 4, 1, 1),) * 2, + # upsample_parameters=((2, 4, 1, 1, 0),) * 2, + # num_res_layers=1, + # num_channels=(8, 8), + # num_res_channels=(8, 8), + # num_embeddings=16, + # embedding_dim=8, + # ddp_sync=False, + # ) + # test_data = torch.randn(1, 1, 16, 16) + # test_script_save(net, test_data) + + def test_num_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_num_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_num_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) + + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 2, + ) + + def test_encode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8, 8)) + + def test_index_quantize_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8)) + + def
test_decode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + def test_decode_samples_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + +if __name__ == "__main__": + unittest.main()
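+
+# Hedged round-trip sketch mirroring the shape tests above (not executed by this suite):
+#
+#     net = VQVAE(**TEST_LATENT_SHAPE)
+#     z = net.encode(torch.randn(1, 1, 32, 32))              # (1, 8, 8, 8) latent
+#     codes = net.index_quantize(torch.randn(1, 1, 32, 32))  # (1, 8, 8) codebook indices
+#     recon = net.decode(z)                                  # (1, 1, 32, 32) reconstruction
+#     recon2 = net.decode_samples(codes)                     # (1, 1, 32, 32) from indices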