-
Notifications
You must be signed in to change notification settings - Fork 433
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
xfguo
committed
Jan 17, 2023
1 parent
c951bb6
commit 8b6b1bb
Showing
11 changed files
with
1,265 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Base model definition: GreenMIM with a Swin-B ('B' arch) backbone.
# The two sizes below are shared by the backbone, neck and head entries.
img_size = 224
patch_size = 4

model = dict(
    type='GreenMIM',
    # Normalisation statistics and channel-order flag for the preprocessor.
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
    ),
    backbone=dict(
        type='GreenMIMSwinTransformer',
        arch='B',
        img_size=img_size,
        patch_size=patch_size,
        drop_path_rate=0.0,
        stage_cfgs=dict(block_cfgs=dict(window_size=7)),
    ),
    neck=dict(
        type='GreenMIMNeck',
        in_channels=3,
        encoder_stride=32,
        img_size=img_size,
        patch_size=patch_size,
    ),
    head=dict(
        type='GreenMIMHead',
        patch_size=patch_size,
        norm_pix_loss=False,
        loss=dict(
            type='SimMIMReconstructionLoss',
            encoder_in_channels=3,
        ),
    ),
)
49 changes: 49 additions & 0 deletions
49
configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Pre-training recipe: start from the base model / dataset / schedule /
# runtime files listed here and override the entries defined below.
_base_ = [
    '../_base_/models/greenmim_swin-base.py',
    '../_base_/datasets/imagenet_mae.py',
    '../_base_/schedules/adamw_coslr-200e_in1k.py',
    '../_base_/default_runtime.py',
]

# dataset: 128 samples per GPU, intended for 16 GPUs
train_dataloader = dict(
    batch_size=128,
    num_workers=16,
)

# optimizer: the 2e-4 reference lr is scaled by 2048 / 512
optimizer = dict(
    type='AdamW',
    lr=2e-4 * 2048 / 512,
    betas=(0.9, 0.999),
    eps=1e-8,
)

# optimizer wrapper (AMP) with gradient clipping
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=optimizer,
    clip_grad=dict(max_norm=5.0),
    paramwise_cfg=dict(
        # parameters matched by these keys get a weight-decay multiplier of 0
        custom_keys={
            'norm': dict(decay_mult=0.0),
            'bias': dict(decay_mult=0.0),
            'absolute_pos_embed': dict(decay_mult=0.0),
            'relative_position_bias_table': dict(decay_mult=0.0),
        },
    ),
)

# learning rate scheduler: linear warm-up over epochs [0, 10), then cosine
# annealing over epochs [10, 100); both are converted to per-iteration steps
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-6 / 2e-4,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True,
    ),
    dict(
        type='CosineAnnealingLR',
        T_max=90,
        eta_min=1e-5 * 2048 / 512,
        by_epoch=True,
        begin=10,
        end=100,
        convert_to_iter_based=True,
    ),
]

# schedule
train_cfg = dict(max_epochs=100)

# runtime: log every 100 iterations
default_hooks = dict(
    logger=dict(type='LoggerHook', interval=100),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, List, Optional, Tuple | ||
|
||
import torch | ||
from mmengine.structures import BaseDataElement | ||
|
||
from mmselfsup.registry import MODELS | ||
from mmselfsup.structures import SelfSupDataSample | ||
from .base import BaseModel | ||
|
||
@MODELS.register_module()
class GreenMIM(BaseModel):
    """GreenMIM.

    Implementation of `GreenMIM: Green Hierarchical Vision Transformer for
    Masked Image Modeling <https://arxiv.org/abs/2205.13515>`_.
    """

    def extract_feat(self,
                     inputs: List[torch.Tensor],
                     data_samples: Optional[List[SelfSupDataSample]] = None,
                     **kwargs) -> Tuple[torch.Tensor]:
        """The forward function to extract features from neck.

        As a side effect, the mask returned by the backbone is stored on
        ``self.mask`` so that :meth:`reconstruct` can use it afterwards.

        Args:
            inputs (List[torch.Tensor]): The input images. Only the first
                element is used.
            data_samples (List[SelfSupDataSample], optional): Unused here.

        Returns:
            Tuple[torch.Tensor]: Neck outputs.
        """
        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)
        # Cached for ``reconstruct``, which has no access to the backbone
        # outputs.
        self.mask = mask
        return pred

    def reconstruct(self,
                    features: torch.Tensor,
                    data_samples: Optional[List[SelfSupDataSample]] = None,
                    **kwargs) -> SelfSupDataSample:
        """The function is for image reconstruction.

        Must be called after :meth:`extract_feat`, because it reads the mask
        stored on ``self.mask`` by that method.

        Args:
            features (torch.Tensor): The patch-wise prediction to turn back
                into an image (it is passed to ``self.head.unpatchify``).
            data_samples (List[SelfSupDataSample], optional): Unused here.
            **kwargs: Must contain ``mean`` and ``std``; the prediction is
                de-normalised as ``features * std + mean``.

        Returns:
            SelfSupDataSample: The prediction from model, holding ``pred``
            and ``mask`` fields, both permuted to ``nhwc`` and moved to CPU.

        Raises:
            KeyError: If ``mean`` or ``std`` is missing from ``kwargs``.
        """
        mean = kwargs['mean']
        std = kwargs['std']
        features = features * std + mean

        pred = self.head.unpatchify(features)
        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()

        # Expand the per-patch mask to one value per predicted pixel so it
        # can go through the same unpatchify as the prediction.
        mask = self.mask.detach()
        mask = mask.unsqueeze(-1).repeat(
            1, 1, self.head.patch_size**2 * 3)  # (N, H*W, p*p*3)
        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).cpu()

        results = SelfSupDataSample()
        results.mask = BaseDataElement(value=mask)
        results.pred = BaseDataElement(value=pred)

        return results

    def patchify(self, imgs: torch.Tensor, patch_size: int) -> torch.Tensor:
        """Split square images into flattened non-overlapping patches.

        Args:
            imgs (torch.Tensor): Images of shape (N, C, H, W) with H == W
                and H divisible by ``patch_size``.
            patch_size (int): Side length of each square patch.

        Returns:
            torch.Tensor: Patches of shape (N, L, patch_size**2 * C), where
            L = (H // patch_size) ** 2. Each patch is flattened in
            (row, column, channel) order.

        Raises:
            ValueError: If ``imgs`` is not 4-dimensional, is not square, or
                its side is not divisible by ``patch_size``.
        """
        p = patch_size
        if imgs.dim() != 4:
            raise ValueError(
                f'imgs must have shape (N, C, H, W), got {tuple(imgs.shape)}')
        num_imgs, channels, height, width = imgs.shape
        if height != width or height % p != 0:
            raise ValueError(
                f'imgs must be square with a side divisible by {p}, '
                f'got H={height}, W={width}')

        h = w = height // p
        # The channel count is taken from the input instead of assuming 3.
        x = imgs.reshape(shape=(num_imgs, channels, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(num_imgs, h * w, p**2 * channels))
        return x

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. Only the first
                element is used.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # ids_restore: the same as that in original repo, which is used
        # to recover the original order of tokens in decoder.
        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)
        # The reconstruction target is the input itself, patchified at the
        # backbone's final patch size.
        target = self.patchify(inputs[0], self.backbone.final_patch_size)
        loss = self.head(pred, target, mask)
        losses = dict(loss=loss)
        return losses
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.