Restormer Implementation for MONAI: High-Resolution Image Restoration #8261

Open
phisanti opened this issue Dec 9, 2024 · 4 comments

@phisanti

phisanti commented Dec 9, 2024

Is your feature request related to a problem? Please describe.

I've noticed that MONAI currently lacks dedicated models for image denoising and restoration tasks. While MONAI provides excellent tools for medical image analysis, having specialized architectures for improving image quality would be valuable for preprocessing pipelines and enhancing low-quality medical images (microscopy, X-ray, scans...).

Describe the solution you'd like

I have implemented a well-documented version of the Restormer model (https://arxiv.org/abs/2111.09881) that could be contributed to MONAI. The implementation includes key components like:

  • Multi-DConv Head Transposed Self-Attention (MDTA) for efficient attention computation
  • Gated-DConv Feed-Forward Network (GDFN) for refined feature selection
  • Modular architecture allowing easy extension and modification
  • Support for flash attention when available
  • Comprehensive documentation of components and architecture
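To make the MDTA bullet concrete, here is a minimal sketch of transposed (channel-wise) attention in plain PyTorch. It leaves out the depth-wise convolutions and the learnable temperature, and the function name `channel_attention` is made up for illustration. The point is only that the attention map has shape (C/heads) x (C/heads), so its size does not grow with the image resolution.

```python
import torch
import torch.nn.functional as F


def channel_attention(q, k, v, num_heads):
    """Transposed attention: tokens are channels, features are pixels.

    q, k, v: tensors of shape (B, C, H, W). Hypothetical helper for
    illustration only, not part of MONAI.
    """
    b, c, h, w = q.shape
    # (B, C, H, W) -> (B, heads, C // heads, H * W)
    q, k, v = (t.reshape(b, num_heads, c // num_heads, h * w) for t in (q, k, v))
    # L2-normalise queries and keys along the flattened spatial axis
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    # Attention map over channels: (B, heads, C // heads, C // heads)
    attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
    out = attn @ v
    return out.reshape(b, c, h, w), attn


x_small = torch.randn(1, 8, 16, 16)
x_large = torch.randn(1, 8, 64, 64)
out_small, attn_small = channel_attention(x_small, x_small, x_small, num_heads=2)
out_large, attn_large = channel_attention(x_large, x_large, x_large, num_heads=2)
# Both attention maps are (1, 2, 4, 4) even though the images differ in size
print(attn_small.shape, attn_large.shape)
```

The cost of building the map still scales with H x W (each entry is a dot product over all pixels), which is where the linear dependence on image size comes from.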

Describe alternatives you've considered

The implementation is already structured in a modular way with clear separation of components. I'm willing to:

  • Refactor the code to meet MONAI coding standards
  • Add appropriate type hints and docstrings
  • Include unit tests
  • Provide example notebooks demonstrating usage
  • Add benchmarks comparing performance

Additional context

The code is currently functional and tested. It supports both standard and dual-pixel tasks, with configurable parameters for network depth, attention heads, and feature dimensions. The implementation prioritizes efficiency through features like flash attention support while maintaining flexibility for different use cases.

@aylward
Collaborator

aylward commented Dec 11, 2024

@Nic-Ma and @ericspod and @KumoLiu - this seems like an outstanding addition to MONAI - agreed?

@phisanti - if all approve, please look at our contribution guidelines. You are already doing the exact right thing by having a modular design. Whenever appropriate, please support the exploration of alternative components in this framework via that modular design and appropriate class abstractions. Please also include multiple tutorials and unit tests with your work.

Does your code currently exist in another repo that we could preliminarily review?

Thanks!

@phisanti
Author

You can take a look at the modular implementation of the Restormer architecture here. I have also copied the code below. As you can see, I keep many of the key blocks intact and focus on expanding functionality (flash attention) and adding modularity to the encoder/decoder blocks. I am happy to implement further changes if you have suggestions.

"""
Restormer: Efficient Transformer for High-Resolution Image Restoration
Implementation based on: https://arxiv.org/abs/2111.09881
"""

from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.layers import Norm
from einops import rearrange

class FeedForward(nn.Module):
    """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection."""
    def __init__(self, dim: int, ffn_expansion_factor: float, bias: bool):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, 
                               stride=1, padding=1, groups=hidden_features*2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        return self.project_out(F.gelu(x1) * x2)

class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA). Differs from standard self-attention
    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
    convolutions for local mixing before attention, achieving complexity that is linear in the
    number of pixels, versus quadratic for vanilla spatial attention."""
    def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
            raise ValueError("Flash attention not available")
            
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention
        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, 
                                   padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self._attention_fn = self._get_attention_fn()

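    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Flash attention implementation using scaled dot-product attention."""
        # Fold the learnable per-head temperature into q and disable the built-in
        # 1/sqrt(d) scaling, so the result matches _normal_attention and the
        # temperature still receives gradients (float(...) would detach it and
        # collapse the per-head values into a single mean).
        # Note: the `scale` argument requires PyTorch >= 2.1.
        return F.scaled_dot_product_attention(
            q * self.temperature,
            k,
            v,
            scale=1.0,
            dropout_p=0.0,
            is_causal=False,
        )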

    def _normal_attention(self, q, k, v):
        """Attention matrix multiplication with depth-wise convolutions."""
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x):
        """Forward pass for MDTA attention. 
        1. Apply depth-wise convolutions to Q, K, V
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention matrix using flash or normal attention
        4. Reshape and project out attention output"""
        b,c,h,w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q,k,v = qkv.chunk(3, dim=1)   
        
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        
        out = self._attention_fn(q, k, v)        
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out


class TransformerBlock(nn.Module):
    """Basic transformer unit combining MDTA and GDFN with skip connections.
    Unlike standard transformers that use LayerNorm, this block uses Instance Norm
    for better adaptation to image restoration tasks."""
    
    def __init__(self, dim: int, num_heads: int, ffn_expansion_factor: float,
                 bias: bool, LayerNorm_type: str, flash_attention: bool = False):
        super().__init__()
        use_bias = LayerNorm_type != 'BiasFree'        
        self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
        self.attn = Attention(dim, num_heads, bias, flash_attention)
        self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class OverlapPatchEmbed(nn.Module):
    """Initial feature extraction using overlapped convolutions.
    Unlike standard patch embeddings that use non-overlapping patches,
    this approach maintains spatial continuity through 3x3 convolutions."""
    
    def __init__(self, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, 
                             stride=1, padding=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

class Downsample(nn.Module):
    """Downsampling module that halves spatial dimensions while doubling channels.
    Uses PixelUnshuffle for efficient feature map manipulation."""
    
    def __init__(self, n_feat: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat//2, kernel_size=3, 
                     stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)

class Upsample(nn.Module):
    """Upsampling module that doubles spatial dimensions while halving channels.
    Combines convolution with PixelShuffle for efficient feature expansion."""
    
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, 
                     stride=1, padding=1, bias=False),
            nn.PixelShuffle(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


##---------- Restormer -----------------------
class Restormer(nn.Module):
    """Restormer: Efficient Transformer for High-Resolution Image Restoration.
    
    Implements a U-Net style architecture with transformer blocks, combining:
    - Multi-scale feature processing through progressive down/upsampling
    - Efficient attention via MDTA blocks
    - Local feature mixing through GDFN
    - Skip connections for preserving spatial details
    
    Architecture:
        - Encoder: Progressive feature downsampling with increasing channels
        - Latent: Deep feature processing at lowest resolution
        - Decoder: Progressive upsampling with skip connections
        - Refinement: Final feature enhancement
    """
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=(1, 1, 1, 1),
                 heads=(1, 1, 1, 1),
                 num_refinement_blocks=4,
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias',
                 dual_pixel_task=False,
                 flash_attention=False):
        """Initialize Restormer model.

        Args:
            inp_channels: Number of input image channels
            out_channels: Number of output image channels
            dim: Base feature dimension
            num_blocks: Number of transformer blocks at each scale
            heads: Number of attention heads at each scale (the channel count
                at each scale, dim * 2**level, must be divisible by it)
            num_refinement_blocks: Number of final refinement blocks
            ffn_expansion_factor: Expansion factor for feed-forward network
            bias: Whether to use bias in convolutions
            LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree');
                'BiasFree' disables the affine parameters of the instance norm layers
            dual_pixel_task: Enable dual-pixel specific processing
            flash_attention: Use flash attention if available
        """
        super().__init__()
        # Check input parameters
        assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
        assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
        assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0"
        
        # Initial feature extraction
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_levels = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        self.decoder_levels = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.reduce_channels = nn.ModuleList()
        num_steps = len(num_blocks) - 1 
        self.num_steps = num_steps

        # Define encoder levels
        for n in range(num_steps):
            current_dim = dim * 2**n
            self.encoder_levels.append(
                nn.Sequential(*[
                    TransformerBlock(
                        dim=current_dim,
                        num_heads=heads[n],
                        ffn_expansion_factor=ffn_expansion_factor,
                        bias=bias,
                        LayerNorm_type=LayerNorm_type,
                        flash_attention=flash_attention
                    ) for _ in range(num_blocks[n])
                ])
            )
            self.downsamples.append(Downsample(current_dim))

        # Define latent space
        latent_dim = dim * 2**num_steps
        self.latent = nn.Sequential(*[
            TransformerBlock(
                dim=latent_dim,
                num_heads=heads[num_steps],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                flash_attention=flash_attention
            ) for _ in range(num_blocks[num_steps])
        ])

        # Define decoder levels
        for n in reversed(range(num_steps)):
            current_dim = dim * 2**n
            next_dim = dim * 2**(n+1)
            self.upsamples.append(Upsample(next_dim))
            
            # Reduce channel layers to deal with skip connections
            if n != 0:
                self.reduce_channels.append(
                    nn.Conv2d(next_dim, current_dim, kernel_size=1, bias=bias)
                    )
                decoder_dim = current_dim
            else:
                decoder_dim = next_dim
            
            self.decoder_levels.append(
                nn.Sequential(*[
                    TransformerBlock(
                        dim=decoder_dim,
                        num_heads=heads[n],
                        ffn_expansion_factor=ffn_expansion_factor,
                        bias=bias,
                        LayerNorm_type=LayerNorm_type,
                        flash_attention=flash_attention
                    ) for _ in range(num_blocks[n])
                ])
            )

        # Final refinement and output
        self.refinement = nn.Sequential(*[
            TransformerBlock(
                dim=decoder_dim,
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                flash_attention=flash_attention
            ) for _ in range(num_refinement_blocks)
        ])
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
            
        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        """Forward pass of Restormer.
        Processes input through encoder-decoder architecture with skip connections.
        Args:
            x: Input image tensor of shape (B, C, H, W)

        Returns:
            Restored image tensor of shape (B, out_channels, H, W)
        """
        assert x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2 ** self.num_steps, "Input height and width must be larger than 2**num_steps"

        # Patch embedding
        x = self.patch_embed(x)
        skip_connections = []

        # Encoding path
        for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
            x = encoder(x)
            skip_connections.append(x)
            x = downsample(x)

        # Latent space
        x = self.latent(x)
        
        # Decoding path
        for idx in range(len(self.decoder_levels)):
            x = self.upsamples[idx](x)           
            x = torch.concat([x, skip_connections[-(idx + 1)]], 1)
            if idx < len(self.decoder_levels) - 1:
                x = self.reduce_channels[idx](x)                
            x = self.decoder_levels[idx](x)
                
        # Final refinement
        x = self.refinement(x)

        if self.dual_pixel_task:
            x = x + self.skip_conv(skip_connections[0])
        x = self.output(x)

        return x

if __name__ == "__main__":
    flash_att = True
    test_model = Restormer(
        inp_channels=2,
        out_channels=2,
        dim=16,
        num_blocks=[1,1,1,1],
        heads=[1,1,1,1],
        num_refinement_blocks=2,
        ffn_expansion_factor=1.5,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True,
        flash_attention=flash_att
    )
    print(f'flash attention set to {flash_att}')
    input_tensor = torch.randn(8, 2, 256, 256)
    print(f"Input shape: {input_tensor.shape}")
    output = test_model(input_tensor)
    print(f"Output shape: {output.shape}")

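    print('printing final model')
    from torchsummary import summary

    # torchsummary expects the per-sample input size (C, H, W), not a tensor,
    # and defaults to device="cuda", so the device is set explicitly here.
    summary(test_model, input_size=tuple(input_tensor.shape[1:]), device="cpu")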
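One practical note on input sizes: each `Downsample` level uses `nn.PixelUnshuffle(2)`, which requires even height and width, so inputs need H and W divisible by 2**num_steps (the assert in `forward` only checks that they are larger). Below is a minimal sketch of padding to the next valid size and cropping back afterwards. It assumes reflect padding is acceptable for the data, and `pad_to_multiple` and `crop_to` are hypothetical helpers, not MONAI functions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_to_multiple(x, multiple):
    """Pad a (B, C, H, W) tensor on the bottom/right so that H and W
    become multiples of `multiple`. Returns the padded tensor and the
    original (H, W) so the output can be cropped back later."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad order for the last two dims: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h), mode="reflect"), (h, w)


def crop_to(x, size):
    """Crop the bottom/right padding away again."""
    h, w = size
    return x[..., :h, :w]


num_steps = 3  # len(num_blocks) - 1 for a four-level configuration
x = torch.randn(1, 2, 250, 253)
x_pad, orig_size = pad_to_multiple(x, 2 ** num_steps)
print(x_pad.shape)  # torch.Size([1, 2, 256, 256])

# The padded tensor now survives repeated pixel-unshuffling
unshuffled = nn.PixelUnshuffle(2)(x_pad)
restored = crop_to(x_pad, orig_size)
```

In use, the model would be called on `x_pad` and its output passed through `crop_to`.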

@phisanti
Author

@aylward, @Nic-Ma, @ericspod and @KumoLiu, if you all agree and there are no comments on extra modules to be added, I will implement the class as is. For that, I will:

  1. Fork and create branch '8261-restormer-implementation'
  2. Place architecture in MONAI/monai/networks/nets folder
  3. Add extensive documentation (docstring + docs) following UNet class style as template
  4. Write unit tests following existing test patterns
  5. Create tutorial notebook with example dataset for the Project-MONAI/tutorials
  6. Submit PRs for both code and tutorial
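For step 4, a shape test could look roughly like the sketch below. It uses plain `unittest` with `subTest` and a stand-in `TinyNet` so that it runs on its own. In the actual PR the stand-in would be replaced by the Restormer class, and MONAI's own tests may rely on different helpers, so the structure here is an assumption rather than the project's required pattern.

```python
import unittest

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Stand-in for the network under test (hypothetical, illustration only)."""

    def __init__(self, inp_channels, out_channels, dim=8):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(inp_channels, dim, kernel_size=3, padding=1),
            nn.Conv2d(dim, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.body(x)


# Each case: (constructor kwargs, input shape, expected output shape)
TEST_CASES = [
    ({"inp_channels": 1, "out_channels": 1}, (2, 1, 32, 32), (2, 1, 32, 32)),
    ({"inp_channels": 3, "out_channels": 3}, (1, 3, 64, 48), (1, 3, 64, 48)),
    ({"inp_channels": 2, "out_channels": 1}, (1, 2, 16, 16), (1, 1, 16, 16)),
]


class TestShape(unittest.TestCase):
    def test_shape(self):
        for kwargs, in_shape, expected in TEST_CASES:
            with self.subTest(kwargs=kwargs, in_shape=in_shape):
                net = TinyNet(**kwargs).eval()
                with torch.no_grad():
                    out = net(torch.randn(in_shape))
                self.assertEqual(tuple(out.shape), expected)


def run_tests():
    """Run the test case programmatically and return the result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestShape)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Calling `run_tests()` runs the cases and returns a result whose `wasSuccessful()` method reports the outcome.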

What aspects of this approach would you modify to fully align with MONAI's contribution standards?

@KumoLiu
Contributor

KumoLiu commented Dec 19, 2024

Hi @phisanti, thank you for sharing the comprehensive plan!

I’d recommend dividing the implementation into several PRs to simplify the review process. Additionally, I highly suggest checking if there are existing blocks in MONAI that can be reused in your network, such as upsample, downsample, attention mechanisms, etc.

https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/downsample.py
https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py
https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/selfattention.py
https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/spatialattention.py

Also, consider using Convolution, which could make your network support both 2D and 3D implementations seamlessly.

class Convolution(nn.Sequential):
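The idea behind this suggestion is to choose the convolution type from a `spatial_dims` argument instead of hard-coding `nn.Conv2d`. The sketch below shows the pattern in plain PyTorch for the overlapped patch embedding. It is a simplified illustration of the approach, not MONAI's `Convolution` class, and `make_conv` and `PatchEmbedND` are hypothetical names.

```python
import torch
import torch.nn as nn

# Map the number of spatial dimensions to the matching PyTorch conv class.
_CONV_TYPES = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}


def make_conv(spatial_dims, in_channels, out_channels, kernel_size=3, bias=False):
    """Hypothetical helper: size-preserving conv for 1D, 2D or 3D inputs."""
    if spatial_dims not in _CONV_TYPES:
        raise ValueError(f"spatial_dims must be 1, 2 or 3, got {spatial_dims}")
    conv_type = _CONV_TYPES[spatial_dims]
    return conv_type(in_channels, out_channels, kernel_size=kernel_size,
                     stride=1, padding=kernel_size // 2, bias=bias)


class PatchEmbedND(nn.Module):
    """Overlapped patch embedding that works for 2D images and 3D volumes."""

    def __init__(self, spatial_dims, in_channels, embed_dim=48, bias=False):
        super().__init__()
        self.proj = make_conv(spatial_dims, in_channels, embed_dim, bias=bias)

    def forward(self, x):
        return self.proj(x)


embed_2d = PatchEmbedND(spatial_dims=2, in_channels=1)
embed_3d = PatchEmbedND(spatial_dims=3, in_channels=1)
out_2d = embed_2d(torch.randn(1, 1, 32, 32))
out_3d = embed_3d(torch.randn(1, 1, 8, 16, 16))
print(out_2d.shape, out_3d.shape)
```

Note that the pixel-shuffle down/upsampling and the `h w` reshapes in the attention block of the posted code are 2D-specific, so they would need their own handling before the whole network could accept 3D volumes.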
