This repository contains the code of our paper:
A Data-Based Perspective on Transfer Learning
Saachi Jain*, Hadi Salman*, Alaa Khaddaj*, Eric Wong, Sung Min Park, Aleksander Madry
Paper - Blog post
@article{jain2022data,
title={A Data-Based Perspective on Transfer Learning},
author={Jain, Saachi and Salman, Hadi and Khaddaj, Alaa and Wong, Eric and Park, Sung Min and Madry, Aleksander},
journal={arXiv preprint arXiv:2207.05739},
year={2022}
}
The main contents of our repo are:
- src/: Contains all our code for running the full transfer pipeline.
- configs/: Contains the config files that the training code expects. These config files contain the hyperparameters for each transfer task.
- analysis/: Contains the code for all the analyses in our paper.
Our code relies on the FFCV Library. To install this library along with other dependencies including PyTorch, follow the instructions below.
conda create -n ffcv python=3.9 cupy pkg-config compilers libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge
conda activate ffcv
pip install ffcv
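After installation, a quick sanity check can confirm that the key packages are visible in the active environment. The snippet below is a minimal sketch; the helper name `missing_packages` is hypothetical and not part of this repo or of FFCV.

```python
import importlib.util


def missing_packages(names):
    """Return the subset of `names` that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Top-level packages installed by the conda/pip commands above.
    missing = missing_packages(["torch", "torchvision", "ffcv", "numba", "cupy"])
    print("All packages found." if not missing else f"Missing: {missing}")
```

The check only looks packages up without importing them, so it stays fast and does not initialize CUDA.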
To train an ImageNet model and transfer it to all the datasets we consider in the paper, simply run:
python src/train_imagenet_class_subset.py \
--config-file configs/base_config.yaml \
--training.data_root $PATH_TO_DATASETS \
--out.output_pkl_dir $OUTDIR
where $OUTDIR is the output directory of your choice, and $PATH_TO_DATASETS is the path where the datasets exist (see below).
The config file configs/base_config.yaml contains all the hyperparameters needed for this experiment. For example, you can specify which downstream tasks to transfer to, or how many ImageNet classes to train the source model on.
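Options such as `--training.data_root` address a section and a key within the config. The sketch below illustrates that idea with a plain nested dictionary. It is not the repo's actual argument parser, `apply_overrides` is a hypothetical helper, and the config values are placeholders rather than the contents of configs/base_config.yaml.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted-key overrides (e.g. 'training.data_root') applied."""
    result = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        *sections, key = dotted_key.split(".")
        node = result
        for section in sections:
            # Create intermediate sections on demand.
            node = node.setdefault(section, {})
        node[key] = value
    return result


# Placeholder values, not the contents of configs/base_config.yaml.
base = {"training": {"data_root": None}, "out": {"output_pkl_dir": None}}
merged = apply_overrides(
    base,
    {"training.data_root": "/path/to/datasets", "out.output_pkl_dir": "/path/to/outdir"},
)
print(merged["training"]["data_root"])
```

Copying the base config before applying overrides keeps the defaults intact, which is convenient when launching several runs from the same file.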
Use analysis/data_compressors/2_20_compressor.py to compress model results into a summary file. Then use analysis/compute_influences.py to compute the influences. In a notebook, simply run the following code:
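import torch
import compute_influences  # analysis/compute_influences.py; assumes analysis/ is on the Python path

# dataset, INFLUENCE_KEY, keyword, and val_labels must already be defined for your experiment.
sf = "<SUMMARY FILE FOLDER>"  # path to the folder containing the summary files
ds = compute_influences.SummaryFileDataSet(sf, dataset, INFLUENCE_KEY, keyword)
dl = torch.utils.data.DataLoader(ds, batch_size=1024, shuffle=False, drop_last=False)
infl = compute_influences.batch_calculate_influence(dl, len(val_labels), 1000, div=True)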
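For intuition, the influence of a source class on a target example can be read as the difference between the model's average output on that example when the class is included in the source dataset and when it is excluded. The dependency-free sketch below estimates this from a set of trained models. It illustrates the definition only and is not the implementation in analysis/compute_influences.py; `estimate_influences` and its arguments are hypothetical names.

```python
def estimate_influences(masks, outputs):
    """Estimate class-to-example influences.

    masks[m][c]   -- True if source class c was in the training subset of model m.
    outputs[m][t] -- output (e.g. correct-class probability) of model m on target example t.
    Returns infl[c][t] = mean(outputs | c included) - mean(outputs | c excluded).
    """
    num_classes = len(masks[0])
    num_targets = len(outputs[0])
    infl = []
    for c in range(num_classes):
        with_c = [outputs[m] for m in range(len(masks)) if masks[m][c]]
        without_c = [outputs[m] for m in range(len(masks)) if not masks[m][c]]
        if not with_c or not without_c:
            raise ValueError(f"class {c} must be both included and excluded at least once")
        row = []
        for t in range(num_targets):
            mean_with = sum(o[t] for o in with_c) / len(with_c)
            mean_without = sum(o[t] for o in without_c) / len(without_c)
            row.append(mean_with - mean_without)
        infl.append(row)
    return infl


# Toy example: 4 models, 2 source classes, 1 target example.
masks = [[True, True], [True, False], [False, True], [False, False]]
outputs = [[0.9], [0.8], [0.5], [0.4]]
infl = estimate_influences(masks, outputs)
```

In this toy example, class 0 receives a larger influence than class 1 because the models trained with it have noticeably higher outputs on the target example.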
Once the influences have been computed, we can run counterfactual experiments by removing the top- or bottom-ranked classes by influence from the source dataset (ImageNet) and then applying transfer learning again. This can be done by running:
python src/counterfactuals_main.py \
--config-file configs/base_config.yaml \
--training.transfer_task ${TASK} \
--out.output_pkl_dir ${OUT_DIR} \
--counterfactual.cf_target_dataset ${DATASET} \
--counterfactual.cf_infl_order_file ${INFL_ORDER_FILE} \
--data.num_classes -1 \
--counterfactual.cf_order TOP \
--counterfactual.cf_num_classes_min ${MIN_STEPS} \
--counterfactual.cf_num_classes_max ${MAX_STEPS} \
--counterfactual.cf_num_classes_step ${STEP_SIZE} \
--counterfactual.cf_type CLASS
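The counterfactual sweep above can be pictured as follows: given classes ranked by influence, remove the first k or the last k of them and retrain on the rest, for a range of k. The sketch below assumes the order file ranks classes from most to least influential, that a `BOTTOM` option mirrors `TOP`, and that the upper bound of the sweep is inclusive. All three are assumptions for illustration, and `classes_to_keep` is a hypothetical helper, not a function from src/.

```python
def classes_to_keep(infl_order, num_removed, order="TOP"):
    """Return the source classes that remain after removing `num_removed` classes.

    infl_order -- class indices sorted from most to least influential (assumed convention).
    order      -- "TOP" removes the most influential classes, "BOTTOM" the least influential.
    """
    if not 0 <= num_removed <= len(infl_order):
        raise ValueError("num_removed must be between 0 and the number of classes")
    if order == "TOP":
        return list(infl_order[num_removed:])
    if order == "BOTTOM":
        return list(infl_order[: len(infl_order) - num_removed])
    raise ValueError(f"unknown order: {order!r}")


# Hypothetical ranking of 6 source classes and a sweep like MIN_STEPS=0, MAX_STEPS=4, STEP_SIZE=2.
infl_order = [3, 0, 5, 1, 4, 2]
for k in range(0, 4 + 1, 2):  # inclusive upper bound is an assumption
    print(k, classes_to_keep(infl_order, k, order="TOP"))
```

Each list of remaining classes would then define the source dataset for one pre-training run, followed by transfer to the target task.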
- aircraft (Download)
- birds (Download)
- caltech101 (Download)
- caltech256 (Download)
- cifar10 (Automatically downloaded when you run the code)
- cifar100 (Automatically downloaded when you run the code)
- flowers (Download)
- food (Download)
- pets (Download)
- stanford_cars (Download)
- SUN397 (Download)
We have created an FFCV version of each of these datasets to enable super fast training. We will make these datasets available soon!
Coming soon!