ViT-UNet is a novel hierarchical ViT-based model, applied to autoencoders via UNet-shaped architectures. Background work can be found in the following links:
This autoencoder structure aims to take advantage of the computational parallelisation of self-attention mechanisms while handling long-term dependencies by stacking multiple encoders, combining encoding and decoding information via skip connections, and hierarchising dependencies in the image representation by varying the patch size.
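As a rough illustration of that pattern, the sketch below stacks transformer encoder layers into an encoding path, a bottleneck, and a decoding path whose inputs are fused with the matching encoder outputs. It is a minimal, self-contained PyTorch example under assumed layer choices, not the actual ViT_UNet implementation; in particular, the convolutional preprocessing and the patch-size changes between levels are omitted.

```python
import torch
import torch.nn as nn

class TinyViTUNet(nn.Module):
    """Illustrative UNet-shaped stack of transformer encoder layers (not ViT_UNet itself)."""
    def __init__(self, hidden_dim=64, num_heads=4, depth=2):
        super().__init__()
        make_layer = lambda: nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.encoders = nn.ModuleList([make_layer() for _ in range(depth)])
        self.bottleneck = make_layer()
        self.decoders = nn.ModuleList([make_layer() for _ in range(depth)])
        # Skip connections: concatenate the matching encoder output with the
        # decoder input, then project back down to hidden_dim.
        self.skip_proj = nn.ModuleList(
            [nn.Linear(2 * hidden_dim, hidden_dim) for _ in range(depth)])

    def forward(self, x):                                # x: (batch, patches, hidden_dim)
        skips = []
        for enc in self.encoders:                        # encoding path
            x = enc(x)
            skips.append(x)
        x = self.bottleneck(x)
        for dec, proj, skip in zip(self.decoders, self.skip_proj, reversed(skips)):
            x = dec(proj(torch.cat([x, skip], dim=-1)))  # fuse skip connection, then decode
        return x

tokens = torch.randn(1, 196, 64)    # 196 patch embeddings of a (3, 224, 224) image
print(TinyViTUNet()(tokens).shape)  # torch.Size([1, 196, 64])
```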
For a given input image of size (3, 224, 224), three versions of this architecture are suggested:
- Lite: number of parameters: 3,387,568

```python
ViT_UNet(depth = 2,
         depth_te = 1,
         size_bottleneck = 2,
         preprocessing = 'conv',
         num_patches = 196,
         patch_size = 16,
         num_channels = 3,
         hidden_dim = 64,
         num_heads = 4,
         attn_drop = .2,
         proj_drop = .2,
         linear_drop = 0,
         dtype = torch.float32,
         )
```
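A minimal usage sketch for the Lite configuration is shown below. It assumes ViT_UNet can be imported from this repository (the module path `vit_unet` is a guess) and that the autoencoder maps an image batch back to its input shape.

```python
import torch
from vit_unet import ViT_UNet  # assumed import path; adjust to your checkout

model = ViT_UNet(depth=2, depth_te=1, size_bottleneck=2, preprocessing='conv',
                 num_patches=196, patch_size=16, num_channels=3, hidden_dim=64,
                 num_heads=4, attn_drop=.2, proj_drop=.2, linear_drop=0,
                 dtype=torch.float32)

x = torch.randn(1, 3, 224, 224)    # dummy (3, 224, 224) input image
with torch.no_grad():
    out = model(x)                 # reconstruction, expected shape (1, 3, 224, 224)
```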
- Base: number of parameters: 36,613,036

```python
ViT_UNet(depth = 2,
         depth_te = 2,
         size_bottleneck = 2,
         preprocessing = 'conv',
         num_patches = 49,
         patch_size = 32,
         num_channels = 3,
         hidden_dim = 128,
         num_heads = 8,
         attn_drop = .2,
         proj_drop = .2,
         linear_drop = 0,
         dtype = torch.float32,
         )
```
- Large: number of parameters: 63,043,866

```python
ViT_UNet(depth = 2,
         depth_te = 4,
         size_bottleneck = 4,
         preprocessing = 'conv',
         num_patches = 49,
         patch_size = 32,
         num_channels = 3,
         hidden_dim = 128,
         num_heads = 8,
         attn_drop = .2,
         proj_drop = .2,
         linear_drop = 0,
         dtype = torch.float32,
         )
```
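The parameter counts listed above can be reproduced with a one-liner once a model instance is available; the sketch below does this for the Large configuration (again assuming the hypothetical `vit_unet` import path).

```python
import torch
from vit_unet import ViT_UNet  # assumed import path; adjust to your checkout

large = ViT_UNet(depth=2, depth_te=4, size_bottleneck=4, preprocessing='conv',
                 num_patches=49, patch_size=32, num_channels=3, hidden_dim=128,
                 num_heads=8, attn_drop=.2, proj_drop=.2, linear_drop=0,
                 dtype=torch.float32)

# Total number of parameters in the model.
print(sum(p.numel() for p in large.parameters()))  # expected: 63043866
```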
The following tasks are to be tested:
- Image denoising.
  - Dataset: SIDD dataset.
  - Two models stand out in the benchmark ranking: HINet (best model on the PSNR metric) and UFormer (best model on the SSIM metric).
- Deblurring.
  - Dataset: GoPro dataset.
  - The top model on the PSNR metric is HINet.
- Single Image Deraining.
  - Multiple datasets are available: Rain100H, Rain100L, ...; the full list can be found here.
- Image segmentation.
  - Dataset: Pancreas Segmentation on TCIA Pancreas-CT.
  - The metric used here is the Dice score, which is the segmentation counterpart of the F1 score, with the Jaccard index (IoU) playing the role of accuracy (see the definitions below). A softer version of this score can be explored here.
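For reference, writing both scores in terms of true/false positives and negatives makes the F1 correspondence explicit (these are the standard definitions, not anything specific to this repository):

$$
\mathrm{Dice}(A,B) = \frac{2\,|A \cap B|}{|A| + |B|} = \frac{2\,TP}{2\,TP + FP + FN} = F_1,
\qquad
\mathrm{IoU}(A,B) = \frac{|A \cap B|}{|A \cup B|} = \frac{TP}{TP + FP + FN}.
$$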
Metrics that are required for these tasks:
- Peak signal-to-noise ratio (PSNR); a minimal computation sketch is given after the Dice loss below.
- Structural Similarity (SSIM)
- Soft Dice Score (implemented here as a loss, i.e. 1 minus the soft Dice score):

```python
import torch

def dice_loss(input: torch.Tensor,
              target: torch.Tensor,
              ) -> torch.Tensor:
    # Soft Dice loss: 1 - soft Dice score, with a smoothing term to avoid
    # division by zero on empty masks.
    smooth = 1.
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))
```
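PSNR can be computed directly from the mean squared error; below is a minimal sketch that assumes images are scaled to [0, 1], so the peak value defaults to 1. SSIM is more involved and is typically taken from an existing library rather than reimplemented.

```python
import torch

def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(MAX^2 / MSE); higher means a better reconstruction.
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)
```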
## Usage
To run a training:

```bash
python3 run_denoising.py --model_string "lite" --im_size "224"
```