[Feature] Add Greenmim Inference #677

Open · wants to merge 18 commits into `dev-1.x`
114 changes: 114 additions & 0 deletions projects/greenmim/README.md
@@ -0,0 +1,114 @@
# GreenMIM Pre-training Model

- [GreenMIM Pre-training Model](#greenmim-pre-training-model)
- [Description](#description)
- [Usage](#usage)
- [Setup Environment](#setup-environment)
- [Data Preparation](#data-preparation)
- [Pre-training Commands](#pre-training-commands)
- [On Local Single GPU](#on-local-single-gpu)
- [On Multiple GPUs](#on-multiple-gpus)
- [On Multiple GPUs with Slurm](#on-multiple-gpus-with-slurm)
- [Citation](#citation)
- [Checklist](#checklist)

## Description

<!-- Share any information you would like others to know. For example:
Author: @xxx.
This is an implementation of \[XXX\]. -->

Author: @xfguo-ucas

This is an implementation of **GreenMIM** for pre-training on ImageNet.

## Usage

<!-- For a typical model, this section should contain the commands for dataset preparation, pre-training, and downstream tasks. You are also encouraged to dump your environment specification to env.yml by `conda env export > env.yml`. -->

### Setup Environment

Requirements:

- MMSelfSup >= 1.0.0rc7

Please refer to the [Get Started](https://mmselfsup.readthedocs.io/en/1.x/get_started.html) documentation of MMSelfSup to complete the installation.
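A typical installation flow with OpenMMLab's `mim` tool is sketched below; treat the exact version pins as assumptions that may need adjusting for your environment.

```shell
# install the OpenMMLab package manager and the runtime dependencies
pip install -U openmim
mim install mmengine
mim install 'mmcv>=2.0.0rc1'

# this project requires MMSelfSup >= 1.0.0rc7
mim install 'mmselfsup>=1.0.0rc7'
```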

### Data Preparation

Please refer to the [data preparation documentation](https://mmclassification.readthedocs.io/en/latest/getting_started.html) of MMClassification (mmcls).
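For reference, the base ImageNet configs typically expect a directory layout like the sketch below; adjust `data_root` in the dataset config if your paths differ.

```text
data/imagenet/
├── meta/
│   ├── train.txt
│   └── val.txt
├── train/
│   ├── n01440764/
│   │   └── n01440764_10026.JPEG
│   └── ...
└── val/
    └── ...
```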

### Pre-training Commands

First, add the current folder to `PYTHONPATH` so that Python can find the model files. From the `projects/greenmim/` root directory, run the command below:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
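To verify that the project modules are now importable, you can run a quick sanity check (the names come from `models/__init__.py` below):

```shell
python -c "from models import GreenMIM, GreenMIMSwinTransformer; print('ok')"
```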

Then run the following commands to train the model:

#### On Local Single GPU

```bash
# train with mim
mim train mmselfsup ${CONFIG} --work-dir ${WORK_DIR}

# a specific command example
mim train mmselfsup configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \
--work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/

# train with scripts
python tools/train.py configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \
--work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/
```

#### On Multiple GPUs

```bash
# train with mim
# a specific command example, 8 GPUs here
mim train mmselfsup configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \
--work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/ \
--launcher pytorch --gpus 8

# train with scripts
bash tools/dist_train.sh configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py 8
```

Note:

- CONFIG: the config file under the directory `configs/`
- WORK_DIR: the working directory to save configs, logs, and checkpoints

#### On Multiple GPUs with Slurm

```bash
# train with mim
mim train mmselfsup configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \
--work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/ \
--launcher slurm --gpus 16 --gpus-per-node 8 \
--partition ${PARTITION}

# train with scripts
GPUS_PER_NODE=8 GPUS=16 bash tools/slurm_train.sh ${PARTITION} greenmim \
configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \
--work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/
```

Note:

- CONFIG: the config file under the directory `configs/`
- WORK_DIR: the working directory to save configs, logs, and checkpoints
- PARTITION: the Slurm partition you are using

## Citation

```bibtex
@inproceedings{huang2022green,
  title={Green Hierarchical Vision Transformer for Masked Image Modeling},
  author={Huang, Lang and You, Shan and Zheng, Mingkai and Wang, Fei and Qian, Chen and Yamasaki, Toshihiko},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
```
35 changes: 35 additions & 0 deletions projects/greenmim/configs/greenmim_swin-base.py
@@ -0,0 +1,35 @@
custom_imports = dict(imports=['models'], allow_failed_imports=False)

# model settings
img_size = 224
patch_size = 4

model = dict(
    type='GreenMIM',
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='GreenMIMSwinTransformer',
        arch='B',
        img_size=img_size,
        embed_dim=128,
        num_heads=[4, 8, 16, 32],
        depths=[2, 2, 18, 2],
        patch_size=patch_size,
        decoder_depth=1,
        drop_path_rate=0.0,
        stage_cfgs=dict(block_cfgs=dict(window_size=7))),
    neck=dict(
        type='GreenMIMNeck',
        in_channels=3,
        encoder_stride=32,
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=128),
    head=dict(
        type='GreenMIMHead',
        patch_size=patch_size,
        norm_pix_loss=False,
        loss=dict(type='MAEReconstructionLoss')))
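For reference, the model defined by this config can be built programmatically through MMEngine's registry. This is a minimal sketch, assuming it is run from `projects/greenmim/` so that `custom_imports` can resolve the local `models` package:

```python
from mmengine.config import Config
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

register_all_modules()  # make MMSelfSup components visible to the registry

# Config.fromfile imports the local `models` package via custom_imports
cfg = Config.fromfile('configs/greenmim_swin-base.py')
model = MODELS.build(cfg.model)
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M params')
```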
49 changes: 49 additions & 0 deletions projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py
@@ -0,0 +1,49 @@
_base_ = [
    './greenmim_swin-base.py',
    '../../../configs/selfsup/_base_/datasets/imagenet_mae.py',
    '../../../configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py',
    '../../../configs/selfsup/_base_/default_runtime.py',
]

# dataset: total batch size 2048 (16 GPUs x 128 samples per GPU)
train_dataloader = dict(batch_size=128, num_workers=16)

# optimizer wrapper
optimizer = dict(
    type='AdamW', lr=2e-4 * 2048 / 512, betas=(0.9, 0.999), eps=1e-8)
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=optimizer,
    clip_grad=dict(max_norm=5.0),
    paramwise_cfg=dict(
        custom_keys={
            'norm': dict(decay_mult=0.0),
            'bias': dict(decay_mult=0.0),
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.)
        }))

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-6 / 2e-4,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=90,
        eta_min=1e-5 * 2048 / 512,
        by_epoch=True,
        begin=10,
        end=100,
        convert_to_iter_based=True)
]

# schedule
train_cfg = dict(max_epochs=100)

# runtime
default_hooks = dict(logger=dict(type='LoggerHook', interval=100))
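The `lr=2e-4 * 2048 / 512` expression encodes the linear scaling rule: a base rate of 2e-4 defined at a reference batch size of 512, scaled by the actual total batch size of 2048 (16 GPUs × 128). A quick check of the arithmetic:

```python
base_lr, ref_batch_size = 2e-4, 512
total_batch_size = 16 * 128  # GPUs x samples per GPU = 2048

effective_lr = base_lr * total_batch_size / ref_batch_size
print(effective_lr)  # 0.0008, i.e. 4x the base rate
```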
8 changes: 8 additions & 0 deletions projects/greenmim/models/__init__.py
@@ -0,0 +1,8 @@
from .greenmim import GreenMIM
from .greenmim_backbone import GreenMIMSwinTransformer
from .greenmim_head import GreenMIMHead
from .greenmim_neck import GreenMIMNeck

__all__ = [
    'GreenMIM', 'GreenMIMSwinTransformer', 'GreenMIMHead', 'GreenMIMNeck'
]
104 changes: 104 additions & 0 deletions projects/greenmim/models/greenmim.py
@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
from mmengine.structures import BaseDataElement

from mmselfsup.models.algorithms.base import BaseModel
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample


@MODELS.register_module()
class GreenMIM(BaseModel):
    """GreenMIM.

    Implementation of `GreenMIM: Green Hierarchical Vision Transformer for
    Masked Image Modeling <https://arxiv.org/abs/2205.13515>`_.
    """

    def extract_feat(self,
                     inputs: List[torch.Tensor],
                     data_samples: Optional[List[SelfSupDataSample]] = None,
                     **kwargs) -> Tuple[torch.Tensor]:
        """The forward function to extract features from the neck.

        Args:
            inputs (List[torch.Tensor]): The input images.

        Returns:
            Tuple[torch.Tensor]: Neck outputs.
        """
        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)
        self.mask = mask
        return pred

    def reconstruct(self,
                    features: torch.Tensor,
                    data_samples: Optional[List[SelfSupDataSample]] = None,
                    **kwargs) -> SelfSupDataSample:
        """The function is for image reconstruction.

        Args:
            features (torch.Tensor): The input images.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            SelfSupDataSample: The prediction from model.
        """
        mean = kwargs['mean']
        std = kwargs['std']
        features = features * std + mean

        pred = self.head.unpatchify(features)
        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()

        mask = self.mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
                                         3)  # (N, H*W, p*p*3)
        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

        results = SelfSupDataSample()
        results.mask = BaseDataElement(**dict(value=mask))
        results.pred = BaseDataElement(**dict(value=pred))

        return results

    def patchify(self, imgs: torch.Tensor, patch_size: int) -> torch.Tensor:
        """Split images into non-overlapping patches.

        Args:
            imgs (torch.Tensor): Input images of shape (N, 3, H, W), where
                H == W and H is divisible by ``patch_size``.
            patch_size (int): Side length of each square patch.

        Returns:
            torch.Tensor: Patchified images of shape
            (N, num_patches, patch_size**2 * 3).
        """
        p = patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # ids_restore: the same as that in the original repo, which is used
        # to recover the original order of tokens in the decoder.
        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)
        target = self.patchify(inputs[0], self.backbone.final_patch_size)
        loss = self.head(pred, target, mask)
        losses = dict(loss=loss)
        return losses
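A minimal usage sketch of `extract_feat` on random inputs, assuming it is run from `projects/greenmim/` so the local `models` package and project config resolve (the `[...]` wrapping matches the `inputs[0]` indexing above):

```python
import torch
from mmengine.config import Config
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

register_all_modules()
model = MODELS.build(
    Config.fromfile('configs/greenmim_swin-base.py').model)
model.eval()

with torch.no_grad():
    # extract_feat expects a list of tensors and reads inputs[0];
    # the neck returns decoder predictions for the masked patches
    pred = model.extract_feat([torch.randn(2, 3, 224, 224)])
```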