Skip to content

Commit

Permalink
add greenmim infer
Browse files Browse the repository at this point in the history
  • Loading branch information
xfguo committed Jan 17, 2023
1 parent c951bb6 commit 8b6b1bb
Show file tree
Hide file tree
Showing 11 changed files with 1,265 additions and 4 deletions.
23 changes: 23 additions & 0 deletions configs/selfsup/_base_/models/greenmim_swin-base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# model settings
img_size = 224
patch_size = 4

model = dict(
type='GreenMIM',
data_preprocessor=dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='GreenMIMSwinTransformer',
arch='B',
img_size=img_size,
patch_size=patch_size,
drop_path_rate=0.0,
stage_cfgs=dict(block_cfgs=dict(window_size=7))),
neck=dict(type='GreenMIMNeck', in_channels=3, encoder_stride=32, img_size=img_size, patch_size=patch_size),
head=dict(
type='GreenMIMHead',
patch_size=patch_size,
norm_pix_loss=False,
loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
_base_ = [
'../_base_/models/greenmim_swin-base.py',
'../_base_/datasets/imagenet_mae.py',
'../_base_/schedules/adamw_coslr-200e_in1k.py',
'../_base_/default_runtime.py',
]

# dataset 16 GPUs x 128
train_dataloader = dict(batch_size=128, num_workers=16)

# optimizer wrapper
optimizer = dict(
type='AdamW', lr=2e-4 * 2048 / 512, betas=(0.9, 0.999), eps=1e-8)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
clip_grad=dict(max_norm=5.0),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.)
}))

# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-6 / 2e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=90,
eta_min=1e-5 * 2048 / 512,
by_epoch=True,
begin=10,
end=100,
convert_to_iter_based=True)
]

# schedule
train_cfg = dict(max_epochs=100)

# runtime
default_hooks = dict(logger=dict(type='LoggerHook', interval=100))
3 changes: 2 additions & 1 deletion mmselfsup/models/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
from .simmim import SimMIM
from .simsiam import SimSiam
from .swav import SwAV
from .greenmim import GreenMIM

__all__ = [
'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',
'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam',
'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA',
'MixMIM'
'MixMIM', 'GreenMIM'
]
103 changes: 103 additions & 0 deletions mmselfsup/models/algorithms/greenmim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
from mmengine.structures import BaseDataElement

from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from .base import BaseModel

@MODELS.register_module()
class GreenMIM(BaseModel):
"""GreenMIM.
Implementation of `GreenMIM: Green Hierarchical Vision Transformer for Masked Image Modeling
<https://arxiv.org/abs/2205.13515>`_.
"""

def extract_feat(self,
inputs: List[torch.Tensor],
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwarg) -> Tuple[torch.Tensor]:
"""The forward function to extract features from neck.
Args:
inputs (List[torch.Tensor]): The input images.
Returns:
Tuple[torch.Tensor]: Neck outputs.
"""
latent, mask, ids_restore = self.backbone(inputs[0])
pred = self.neck(latent, ids_restore)
self.mask = mask
return pred

def reconstruct(self,
features: torch.Tensor,
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> SelfSupDataSample:
"""The function is for image reconstruction.
Args:
features (torch.Tensor): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
SelfSupDataSample: The prediction from model.
"""
mean = kwargs['mean']
std = kwargs['std']
features = features * std + mean

pred = self.head.unpatchify(features)
pred = torch.einsum('nchw->nhwc', pred).detach().cpu()

mask = self.mask.detach()
mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
3) # (N, H*W, p*p*3)
mask = self.head.unpatchify(mask) # 1 is removing, 0 is keeping
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

results = SelfSupDataSample()
results.mask = BaseDataElement(**dict(value=mask))
results.pred = BaseDataElement(**dict(value=pred))

return results

def patchify(self, imgs, patch_size):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = patch_size
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x

def loss(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
# ids_restore: the same as that in original repo, which is used
# to recover the original order of tokens in decoder.
latent, mask, ids_restore = self.backbone(inputs[0])
pred = self.neck(latent, ids_restore)
target = self.patchify(inputs[0], self.backbone.final_patch_size)
loss = self.head(pred, target, mask)
losses = dict(loss=loss)
return losses
3 changes: 2 additions & 1 deletion mmselfsup/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from .resnet import ResNet, ResNetSobel, ResNetV1d
from .resnext import ResNeXt
from .simmim_swin import SimMIMSwinTransformer
from .greenmim import GreenMIMSwinTransformer

__all__ = [
'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT',
'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT', 'MILANViT',
'MixMIMTransformerPretrain'
'MixMIMTransformerPretrain', 'GreenMIMSwinTransformer'
]
Loading

0 comments on commit 8b6b1bb

Please sign in to comment.