Official PyTorch implementation of "Distribution-Aware Prompt Tuning for Vision-Language Models" (ICCV 2023).
git clone https://github.com/mlvlab/DAPT.git
cd DAPT
Follow DATASET.md to install the datasets.
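Typically each dataset ends up in its own subdirectory under a single root; the names below are illustrative only, and DATASET.md is authoritative for the exact structure.
$DATA/
├── imagenet/
├── caltech-101/
└── ...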
Before creating the environment, modify the conda path in env.yaml to match your system.
conda env create --file env.yaml
conda activate dapt
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
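To sanity-check the install, this one-liner should print the PyTorch version and confirm that CUDA is visible:
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"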
Set up the Dassl.pytorch package:
cd Dassl.pytorch
python setup.py develop
cd ..
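A quick way to confirm the package was installed in development mode (the printed path should point inside the Dassl.pytorch directory):
python -c "import dassl; print(dassl.__file__)"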
Modify the data path $DATA in main.sh, gen_prototype.sh, and eval.sh to match the path to the dataset you downloaded.
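The scripts are expected to define the dataset root near the top; a representative edit looks like the line below, though the exact variable name and placement may differ per script:
DATA=/path/to/datasets  # replace with your dataset root from DATASET.md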
When the dataset is ready, you can generate the prototype as follows.
bash scripts/gen_prototype.sh [gpu_id]
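For example, to generate prototypes on GPU 0:
bash scripts/gen_prototype.sh 0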
Below is an example for Caltech101 at each shot setting; a loop covering all shots is sketched after the list. Note that for ImageNet, we use configs/trainers/DAPT/vit_b16_ep50.yaml for all settings, following CoOp.
# 1shot
bash scripts/main.sh caltech101 1 [gpu_id]
# 2shots
bash scripts/main.sh caltech101 2 [gpu_id]
# 4shots
bash scripts/main.sh caltech101 4 [gpu_id]
# 8shots
bash scripts/main.sh caltech101 8 [gpu_id]
# 16shots
bash scripts/main.sh caltech101 16 [gpu_id]
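As noted above, the shot settings can also be run sequentially with a simple loop; this sketch assumes GPU 0 and the caltech101 dataset name used above:
# run 1/2/4/8/16 shots one after another on GPU 0
for shots in 1 2 4 8 16; do
    bash scripts/main.sh caltech101 $shots 0
done
The same pattern applies to other datasets by swapping the dataset name passed to main.sh.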
Before running domain generalization, you must first complete few-shot image classification on ImageNet. Once the ImageNet experiment is finished, you can load the model trained on ImageNet with the --eval-only flag to conduct domain generalization on ImageNetV2, ImageNet-Sketch, ImageNet-A, and ImageNet-R.
bash scripts/eval.sh [gpu_id]
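For example, to evaluate on GPU 0:
bash scripts/eval.sh 0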
This repository is built upon Dassl.pytorch, CoOp, and VPT. We thank the authors for their code.
If you use this code in your research, please cite the following paper:
@InProceedings{Cho_2023_ICCV,
author = {Cho, Eulrang and Kim, Jooyeon and Kim, Hyunwoo J},
title = {Distribution-Aware Prompt Tuning for Vision-Language Models},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2023},
pages = {22004-22013}
}
Licensed under the MIT License.
Copyright (c) 2023 MLV Lab (Machine Learning and Vision Lab at Korea University)