Restormer Implementation for MONAI: High-Resolution Image Restoration #8261
@Nic-Ma and @ericspod and @KumoLiu - this seems like an outstanding addition to MONAI - agreed? @phisanti - if all approve, please look at our contribution guidelines. You are already doing the exact right thing by having a modular design. Whenever appropriate, please support the exploration of alternative components in this framework via that modular design and appropriate class abstractions. Please also include multiple tutorials and unit tests with your work. Does your code currently exist in another repo that we could preliminarily review? Thanks!
You can take a look at the modular implementation of the Restormer architecture here; the code is also copied below. As you can see, I keep many of the key blocks intact and focus on expanding functionality (flash attention) and adding modularity to the encoder/decoder blocks. I am happy to implement further changes if a good suggestion is made.

```python
"""
Restormer: Efficient Transformer for High-Resolution Image Restoration
Implementation based on: https://arxiv.org/abs/2111.09881
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.layers import Norm
from einops import rearrange
class FeedForward(nn.Module):
"""Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection."""
def __init__(self, dim: int, ffn_expansion_factor: float, bias: bool):
super().__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
stride=1, padding=1, groups=hidden_features*2, bias=bias)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
return self.project_out(F.gelu(x1) * x2)
class Attention(nn.Module):
"""Multi-DConv Head Transposed Self-Attention (MDTA) Differs from standard self-attention
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
convolutions for local mixing before attention, achieving linear complexity vs quadratic
in vanilla attention."""
def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
super().__init__()
if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
raise ValueError("Flash attention not available")
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.flash_attention = flash_attention
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1,
padding=1, groups=dim*3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self._attention_fn = self._get_attention_fn()
def _get_attention_fn(self):
if self.flash_attention:
return self._flash_attention
return self._normal_attention
def _flash_attention(self, q, k, v):
"""Flash attention implementation using scaled dot-product attention."""
scale = float(self.temperature.mean())
out = F.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
dropout_p=0.0,
is_causal=False
)
return out
def _normal_attention(self, q, k, v):
"""Attention matrix multiplication with depth-wise convolutions."""
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
return attn @ v
def forward(self, x):
"""Forward pass for MDTA attention.
1. Apply depth-wise convolutions to Q, K, V
2. Reshape Q, K, V for multi-head attention
3. Compute attention matrix using flash or normal attention
4. Reshape and project out attention output"""
b,c,h,w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q,k,v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
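        # L2-normalise along the spatial/token dimension so the channel-channel attention
        # below effectively computes cosine similarities between feature maps.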
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
out = self._attention_fn(q, k, v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
class TransformerBlock(nn.Module):
"""Basic transformer unit combining MDTA and GDFN with skip connections.
Unlike standard transformers that use LayerNorm, this block uses Instance Norm
for better adaptation to image restoration tasks."""
def __init__(self, dim: int, num_heads: int, ffn_expansion_factor: float,
bias: bool, LayerNorm_type: str, flash_attention: bool = False):
super().__init__()
use_bias = LayerNorm_type != 'BiasFree'
self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
self.attn = Attention(dim, num_heads, bias, flash_attention)
self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
class OverlapPatchEmbed(nn.Module):
"""Initial feature extraction using overlapped convolutions.
Unlike standard patch embeddings that use non-overlapping patches,
this approach maintains spatial continuity through 3x3 convolutions."""
def __init__(self, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
super().__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
stride=1, padding=1, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)
class Downsample(nn.Module):
"""Downsampling module that halves spatial dimensions while doubling channels.
Uses PixelUnshuffle for efficient feature map manipulation."""
def __init__(self, n_feat: int):
super().__init__()
self.body = nn.Sequential(
nn.Conv2d(n_feat, n_feat//2, kernel_size=3,
stride=1, padding=1, bias=False),
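            # The conv halves the channels; PixelUnshuffle(2) then trades 2x2 spatial blocks for a
            # 4x channel increase, giving 2*n_feat channels at half the spatial resolution overall.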
nn.PixelUnshuffle(2)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.body(x)
class Upsample(nn.Module):
"""Upsampling module that doubles spatial dimensions while halving channels.
Combines convolution with PixelShuffle for efficient feature expansion."""
def __init__(self, in_channels: int) -> None:
super().__init__()
self.body = nn.Sequential(
nn.Conv2d(in_channels, in_channels * 2, kernel_size=3,
stride=1, padding=1, bias=False),
nn.PixelShuffle(2)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.body(x)
##---------- Restormer -----------------------
class Restormer(nn.Module):
"""Restormer: Efficient Transformer for High-Resolution Image Restoration.
Implements a U-Net style architecture with transformer blocks, combining:
- Multi-scale feature processing through progressive down/upsampling
- Efficient attention via MDTA blocks
- Local feature mixing through GDFN
- Skip connections for preserving spatial details
Architecture:
- Encoder: Progressive feature downsampling with increasing channels
- Latent: Deep feature processing at lowest resolution
- Decoder: Progressive upsampling with skip connections
- Refinement: Final feature enhancement
"""
def __init__(self,
inp_channels=3,
out_channels=3,
dim=48,
num_blocks=[1, 1, 1, 1],
heads=[1, 1, 1, 1],
num_refinement_blocks=4,
ffn_expansion_factor=2.66,
bias=False,
LayerNorm_type='WithBias',
dual_pixel_task=False,
flash_attention=False):
        """Initialize Restormer model.

        Args:
            inp_channels: Number of input image channels
            out_channels: Number of output image channels
            dim: Base feature dimension
            num_blocks: Number of transformer blocks at each scale
            heads: Number of attention heads at each scale
            num_refinement_blocks: Number of final refinement blocks
            ffn_expansion_factor: Expansion factor for feed-forward network
            bias: Whether to use bias in convolutions
            LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree')
            dual_pixel_task: Enable dual-pixel specific processing
            flash_attention: Use flash attention if available
        """
        super().__init__()
# Check input parameters
assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0"
# Initial feature extraction
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_levels = nn.ModuleList()
self.downsamples = nn.ModuleList()
self.decoder_levels = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.reduce_channels = nn.ModuleList()
num_steps = len(num_blocks) - 1
self.num_steps = num_steps
# Define encoder levels
for n in range(num_steps):
current_dim = dim * 2**n
self.encoder_levels.append(
nn.Sequential(*[
TransformerBlock(
dim=current_dim,
num_heads=heads[n],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_blocks[n])
])
)
self.downsamples.append(Downsample(current_dim))
# Define latent space
latent_dim = dim * 2**num_steps
self.latent = nn.Sequential(*[
TransformerBlock(
dim=latent_dim,
num_heads=heads[num_steps],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_blocks[num_steps])
])
# Define decoder levels
for n in reversed(range(num_steps)):
current_dim = dim * 2**n
next_dim = dim * 2**(n+1)
self.upsamples.append(Upsample(next_dim))
# Reduce channel layers to deal with skip connections
if n != 0:
self.reduce_channels.append(
nn.Conv2d(next_dim, current_dim, kernel_size=1, bias=bias)
)
decoder_dim = current_dim
else:
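                # At the final (highest-resolution) decoder level the concatenated skip channels
                # are kept at 2*dim rather than being reduced back to dim.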
decoder_dim = next_dim
self.decoder_levels.append(
nn.Sequential(*[
TransformerBlock(
dim=decoder_dim,
num_heads=heads[n],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_blocks[n])
])
)
# Final refinement and output
self.refinement = nn.Sequential(*[
TransformerBlock(
dim=decoder_dim,
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_refinement_blocks)
])
self.dual_pixel_task = dual_pixel_task
if self.dual_pixel_task:
self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
def forward(self, x):
"""Forward pass of Restormer.
Processes input through encoder-decoder architecture with skip connections.
Args:
            x: Input image tensor of shape (B, C, H, W)
Returns:
Restored image tensor of shape (B, C, H, W)
"""
        assert x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2 ** self.num_steps, "Input spatial dimensions must be larger than 2**num_steps"
# Patch embedding
x = self.patch_embed(x)
skip_connections = []
# Encoding path
for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
x = encoder(x)
skip_connections.append(x)
x = downsample(x)
# Latent space
x = self.latent(x)
# Decoding path
for idx in range(len(self.decoder_levels)):
x = self.upsamples[idx](x)
x = torch.concat([x, skip_connections[-(idx + 1)]], 1)
if idx < len(self.decoder_levels) - 1:
x = self.reduce_channels[idx](x)
x = self.decoder_levels[idx](x)
# Final refinement
x = self.refinement(x)
if self.dual_pixel_task:
x = x + self.skip_conv(skip_connections[0])
x = self.output(x)
else:
x = self.output(x)
return x
if __name__ == "__main__":
flash_att = True
test_model = Restormer(
inp_channels=2,
out_channels=2,
dim=16,
num_blocks=[1,1,1,1],
heads=[1,1,1,1],
num_refinement_blocks=2,
ffn_expansion_factor=1.5,
bias=False,
LayerNorm_type='WithBias',
dual_pixel_task=True,
flash_attention=flash_att
)
print(f'flash attention set to {flash_att}')
input_tensor = torch.randn(8, 2, 256, 256)
print(f"Input shape: {input_tensor.shape}")
output = test_model(input_tensor)
print(f"Output shape: {output.shape}")
    print("Printing final model summary")
    from torchsummary import summary
    # torchsummary expects input_size as a (C, H, W) tuple rather than a tensor.
    summary(test_model, input_size=(2, 256, 256))
```
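For the unit tests mentioned above, a minimal shape check could look roughly like the sketch below. This is only an illustration: the import path is hypothetical (it assumes the class above is packaged as a `restormer` module), and the constructor arguments mirror the snippet.

```python
import unittest

import torch

from restormer import Restormer  # hypothetical import path for the class shown above


class TestRestormerShape(unittest.TestCase):
    def test_output_matches_input_shape(self):
        # Small configuration so the test runs quickly on CPU.
        net = Restormer(
            inp_channels=1,
            out_channels=1,
            dim=16,
            num_blocks=[1, 1, 1, 1],
            heads=[1, 1, 1, 1],
            num_refinement_blocks=1,
        )
        x = torch.randn(1, 1, 64, 64)
        with torch.no_grad():
            y = net(x)
        # A restoration network should preserve the spatial size of its input.
        self.assertEqual(tuple(y.shape), (1, 1, 64, 64))


if __name__ == "__main__":
    unittest.main()
```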
@aylward @Nic-Ma @ericspod and @KumoLiu, if you all agree and there are no comments on extra modules to be added, I will implement the class as it is. For that, I will:
What aspects of this approach would you modify to fully align with MONAI's contribution standards?
Hi @phisanti, thank you for sharing the comprehensive plan! I’d recommend dividing the implementation into several PRs to simplify the review process. Additionally, I highly suggest checking whether there are existing blocks in MONAI that can be reused in your network (upsample, downsample, attention mechanisms, etc.), for example: https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/downsample.py
Also, consider using Convolution, which could make your network support both 2D and 3D data seamlessly.
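For illustration, a sketch of what that could look like for the patch-embedding layer, using the existing `monai.networks.blocks.Convolution`; the `spatial_dims` switch is what provides the 2D/3D flexibility, and the class shown is a hypothetical adaptation of the `OverlapPatchEmbed` above rather than a finalized design:

```python
import torch.nn as nn
from monai.networks.blocks import Convolution


class OverlapPatchEmbed(nn.Module):
    """Hypothetical spatial-dim-agnostic variant of the OverlapPatchEmbed block above."""

    def __init__(self, spatial_dims: int = 2, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
        super().__init__()
        self.proj = Convolution(
            spatial_dims=spatial_dims,  # 2 -> Conv2d, 3 -> Conv3d under the hood
            in_channels=in_c,
            out_channels=embed_dim,
            kernel_size=3,
            strides=1,
            padding=1,
            bias=bias,
            conv_only=True,  # plain convolution, no norm/activation wrapper
        )

    def forward(self, x):
        return self.proj(x)
```

The same `spatial_dims` argument could then be threaded through the attention, feed-forward, and up/downsampling blocks so the whole network also handles volumetric data.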
Is your feature request related to a problem? Please describe.
I've noticed that MONAI currently lacks dedicated models for image denoising and restoration tasks. While MONAI provides excellent tools for medical image analysis, having specialized architectures for improving image quality would be valuable for preprocessing pipelines and enhancing low-quality medical images (microscopy, X-ray, scans...).
Describe the solution you'd like
I have implemented a well-documented version of the Restormer model (https://arxiv.org/abs/2111.09881) that could be contributed to MONAI. The implementation includes key components like:
Describe alternatives you've considered
The implementation is already structured in a modular way with clear separation of components. I'm willing to:
Additional context
The code is currently functional and tested. It supports both standard and dual-pixel tasks, with configurable parameters for network depth, attention heads, and feature dimensions. The implementation prioritizes efficiency through features like flash attention support while maintaining flexibility for different use cases.