diff --git a/task_vectors/README.md b/task_vectors/README.md new file mode 100644 index 0000000..9c7bb44 --- /dev/null +++ b/task_vectors/README.md @@ -0,0 +1,136 @@ +# Editing Models with Task Arithmetic + +This repository contains code for the ICLR 2023 paper [Editing Models with Task Arithmetic](https://arxiv.org/abs/2212.04089), by Gabriel Ilharco, Marco Tulio Ribeiro, Mitchell Wortsman, Suchin Gururangan, Ludwig Schmidt, Hannaneh Hajishirzi and Ali Farhadi. + +### Abstract +*Changing how pre-trained models behave---e.g., improving their performance on a downstream task or mitigating biases learned during pre-training---is a common practice when developing machine learning systems. In this work, we propose a new paradigm for steering the behavior of neural networks, centered around task vectors. A task vector specifies a direction in the weight space of a pre-trained model, such that movement in that direction improves performance on the task. We build task vectors by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning on a task. We show that these task vectors can be modified and combined together through arithmetic operations such as negation and addition, and the behavior of the resulting model is steered accordingly. Negating a task vector decreases performance on the target task, with little change in model behavior on control tasks. Moreover, adding task vectors together can improve performance on multiple tasks at once. Finally, when tasks are linked by an analogy relationship of the form ``A is to B as C is to D", combining task vectors from three of the tasks can improve performance on the fourth, even when no data from the fourth task is used for training. Overall, our experiments with several models, modalities and tasks show that task arithmetic is a simple, efficient and effective way of editing models.* + + +### Summary figure + +

+![scatter](img/task_vectors.png)

+
+An illustration of task vectors and the arithmetic operations we study for editing models. (a) A task vector is obtained by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning. (b) Negating a task vector degrades performance on the task, without substantial changes in control tasks. (c) Adding task vectors together improves the performance of the pre-trained model on the tasks under consideration. (d) When tasks form an analogy relationship such as supervised and unsupervised learning on two different data sources, it is possible to improve performance on a supervised target task using only vectors from the remaining three combinations of objectives and datasets.
+
+## Code
+
+### Install dependencies
+
+```bash
+conda env create
+conda activate task-vectors
+```
+
+### Add directory to PYTHONPATH:
+
+```bash
+cd task_vectors
+export PYTHONPATH="$PYTHONPATH:$PWD"
+```
+
+### Using task vectors
+
+The task vector logic can be found at [src/task_vectors.py](src/task_vectors.py).
+
+To create a task vector, you will need a pre-trained checkpoint and a fine-tuned checkpoint:
+
+```python
+from task_vectors import TaskVector
+task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
+```
+
+Once created, task vectors can be modified and combined through arithmetic operations! For instance, to negate a task vector, simply use the ```-``` operator:
+
+```python
+# Negating a task vector
+new_task_vector = -task_vector
+```
+
+To add task vectors, you can use the ```+``` operator, or ```sum```:
+
+```python
+# Adding two task vectors
+new_task_vector = task_vector_A + task_vector_B
+# Adding multiple task vectors
+new_task_vector = sum(list_of_task_vectors)
+```
+
+Task analogies can be expressed just as simply:
+
+```python
+# Task analogies
+new_task_vector = task_vector_C + task_vector_B - task_vector_A
+```
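+
+Under the hood, a task vector is just a dictionary mapping parameter names to weight deltas, and all of the operators above act element-wise on that dictionary. The sketch below illustrates the core logic; it is a simplified stand-in for [src/task_vectors.py](src/task_vectors.py) (the class name and details are ours, not the repo's exact code), and it assumes checkpoints are whole serialized models, as in the examples below:
+
+```python
+import torch
+
+class TaskVectorSketch:
+    """A task vector: the element-wise difference finetuned - pretrained."""
+
+    def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
+        if vector is not None:
+            self.vector = vector
+        else:
+            with torch.no_grad():
+                pre = torch.load(pretrained_checkpoint).state_dict()
+                fin = torch.load(finetuned_checkpoint).state_dict()
+                # Skip integer buffers (e.g. counters), which cannot be meaningfully subtracted
+                self.vector = {k: fin[k] - pre[k] for k in pre
+                               if pre[k].dtype not in (torch.int64, torch.uint8)}
+
+    def __neg__(self):
+        return TaskVectorSketch(vector={k: -v for k, v in self.vector.items()})
+
+    def __add__(self, other):
+        if not isinstance(other, TaskVectorSketch):  # lets the builtin sum() start from 0
+            return self
+        return TaskVectorSketch(vector={k: v + other.vector[k] for k, v in self.vector.items()})
+
+    __radd__ = __add__
+
+    def __sub__(self, other):
+        return self + (-other)
+
+    def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
+        """Return the pre-trained model with scaling_coef * vector added to its weights."""
+        with torch.no_grad():
+            model = torch.load(pretrained_checkpoint)
+            state_dict = model.state_dict()
+            for k, v in self.vector.items():
+                state_dict[k] = state_dict[k] + scaling_coef * v
+            model.load_state_dict(state_dict)
+        return model
+```
+
+Because everything happens in weight space, editing a model costs a single pass over the state dict and no gradient computation; the `scaling_coef` passed to `apply_to` is the scaling coefficient that the paper tunes on held-out data.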
+
+### Checkpoints
+
+Checkpoints for CLIP ViT-B/32, ViT-B/16 and ViT-L/14 are available at the link below, including fine-tuned checkpoints on eight downstream tasks: Stanford Cars, DTD, EuroSAT, GTSRB, MNIST, RESISC45, SUN397 and SVHN.
+
+[Download here](https://drive.google.com/drive/folders/1u_Tva6x0p6oxu5Eo0ZZsf-520Cc_3MKw?usp=share_link)
+
+### Examples
+
+Below is an example of negating a task vector from MNIST, then evaluating on MNIST and on ImageNet:
+
+```python
+import torch
+from task_vectors import TaskVector
+from eval import eval_single_dataset
+from args import parse_arguments
+
+# Config
+dataset = 'MNIST'
+model = 'ViT-L-14'
+args = parse_arguments()
+args.data_location = '/path/to/data'
+args.model = model
+args.save = f'checkpoints/{model}'
+pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt'
+finetuned_checkpoint = f'checkpoints/{model}/{dataset}/finetuned.pt'
+
+# Create the task vector
+task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
+# Negate the task vector
+neg_task_vector = -task_vector
+# Apply the task vector
+image_encoder = neg_task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.5)
+# Evaluate
+eval_single_dataset(image_encoder, dataset, args)
+eval_single_dataset(image_encoder, 'ImageNet', args)
+```
+
+Below is an example of adding task vectors from MNIST and RESISC45, then evaluating on both:
+
+```python
+import torch
+from task_vectors import TaskVector
+from eval import eval_single_dataset
+from args import parse_arguments
+
+# Config
+datasets = ['MNIST', 'RESISC45']
+model = 'ViT-L-14'
+args = parse_arguments()
+args.data_location = '/path/to/data'
+args.model = model
+args.save = f'checkpoints/{model}'
+pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt'
+
+# Create the task vectors
+task_vectors = [
+    TaskVector(pretrained_checkpoint, f'checkpoints/{model}/{dataset}/finetuned.pt')
+    for dataset in datasets
+]
+# Sum the task vectors
+task_vector_sum = sum(task_vectors)
+# Apply the resulting task vector
+image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=0.8)
+# Evaluate
+for dataset in datasets:
+    eval_single_dataset(image_encoder, dataset, args)
+```
diff --git a/task_vectors/download_pretrained_model.py b/task_vectors/download_pretrained_model.py
new file mode 100644
index 0000000..b102499
--- /dev/null
+++ b/task_vectors/download_pretrained_model.py
@@ -0,0 +1,49 @@
+import argparse
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import get_peft_model, LoraConfig, IA3Config
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_name', type=str, required=True)
+parser.add_argument('--pem', type=str, default=None)
+args = parser.parse_args()
+
+model = AutoModelForCausalLM.from_pretrained(
+    args.model_name,
+    device_map='auto',
+)
+
+# Optionally wrap Llama models in a parameter-efficient module (PEM), so that
+# later fine-tuning only trains the adapter weights.
+if 'Llama' in args.model_name and args.pem is not None:
+    if args.pem == 'lora_adapter':
+        llama_peft_config = LoraConfig(
+            task_type="CAUSAL_LM",
+            r=8,
+            lora_dropout=0.01,
+        )
+    elif args.pem == 'ia3_adapter':
+        llama_peft_config = IA3Config(
+            peft_type="IA3",
+            task_type="CAUSAL_LM",
+        )
+    else:
+        raise ValueError("Invalid PEM type for Llama model.")
+    model = get_peft_model(model, llama_peft_config)
+    model.print_trainable_parameters()
+
+tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+# Llama tokenizers ship without a pad token; reuse EOS for padding.
+if 'Llama' in args.model_name:
+    model.config.pad_token_id = model.config.eos_token_id
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
+model.save_pretrained("./outputs/pretrained/" + args.model_name)
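+# Save the tokenizer alongside the weights so the checkpoint directory is
+# self-contained and can be reloaded later (e.g. as the --pretrained_model_path
+# of negation.py) without re-downloading anything from the Hub.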
+tokenizer.save_pretrained("./outputs/pretrained/" + args.model_name) diff --git a/task_vectors/environment.yml b/task_vectors/environment.yml new file mode 100644 index 0000000..d123e98 --- /dev/null +++ b/task_vectors/environment.yml @@ -0,0 +1,108 @@ +name: task-vectors +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotlipy=0.7.0=py310h5764c6d_1004 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2022.9.24=ha878542_0 + - cffi=1.15.1=py310h74dc2b5_0 + - cryptography=37.0.2=py310h597c629_0 + - cudatoolkit=11.6.0=hecad31d_10 + - ffmpeg=4.3=hf484d3e_0 + - freetype=2.10.4=h0708190_1 + - giflib=5.2.1=h36c2ea0_2 + - gmp=6.2.1=h58526e2_0 + - gnutls=3.6.13=h85f3911_1 + - intel-openmp=2021.4.0=h06a4308_3561 + - jbig=2.1=h7f98852_2003 + - jpeg=9e=h166bdaf_1 + - lame=3.100=h7f98852_1001 + - lcms2=2.12=hddcbb42_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=2.2.1=h9c3ff4c_0 + - libdeflate=1.7=h7f98852_5 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.17=h166bdaf_0 + - libpng=1.6.37=h21135ba_2 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtiff=4.3.0=hf544144_1 + - libuuid=1.0.3=h7f8727e_2 + - libwebp=1.2.2=h3452ae3_0 + - libwebp-base=1.2.2=h7f98852_1 + - lz4-c=1.9.3=h9c3ff4c_1 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py310ha2c4b55_0 + - mkl_fft=1.3.1=py310h2b4bcf5_1 + - mkl_random=1.2.2=py310h00e6091_0 + - ncurses=6.3=h5eee18b_3 + - nettle=3.6=he412f7d_0 + - numpy-base=1.23.1=py310hcba007f_0 + - openh264=2.1.1=h780b84a_0 + - openssl=1.1.1o=h166bdaf_0 + - pip=22.2.2=py310h06a4308_0 + - python=3.10.4=h12debd9_0 + - python_abi=3.10=2_cp310 + - pytorch=1.12.1=py3.10_cuda11.6_cudnn8.3.2_0 + - pytorch-mutex=1.0=cuda + - readline=8.1.2=h7f8727e_1 + - setuptools=63.4.1=py310h06a4308_0 + - sqlite=3.39.3=h5082296_0 + - tk=8.6.12=h1ccaba5_0 + - typing_extensions=4.3.0=pyha770c72_0 + - tzdata=2022c=h04d1e81_0 + - xz=5.2.6=h5eee18b_0 + - zlib=1.2.12=h5eee18b_3 + - zstd=1.5.0=ha95c52a_0 + - pip: + - certifi==2022.9.24 + - charset-normalizer==2.1.1 + - contourpy==1.0.5 + - cvxpy==1.2.2 + - cycler==0.11.0 + - ecos==2.0.10 + - filelock==3.8.0 + - fonttools==4.37.4 + - ftfy==6.1.1 + - huggingface-hub==0.10.0 + - idna==3.4 + - kiwisolver==1.4.4 + - matplotlib==3.6.0 + - numpy==1.23.3 + - open-clip-torch==2.0.2 + - osqp==0.6.2.post5 + - packaging==21.3 + - pandas==1.5.0 + - patsy==0.5.2 + - pillow==9.2.0 + - plotly==5.11.0 + - pycparser==2.21 + - pyopenssl==22.0.0 + - pyparsing==3.0.9 + - pysocks==1.7.1 + - python-dateutil==2.8.2 + - pytz==2022.4 + - pyyaml==6.0 + - qdldl==0.1.5.post2 + - regex==2022.9.13 + - requests==2.28.1 + - scipy==1.9.1 + - scs==3.2.2 + - seaborn==0.12.1 + - six==1.16.0 + - statsmodels==0.13.2 + - tenacity==8.1.0 + - torch==1.12.1 + - torchaudio==0.12.1+cu116 + - torchvision==0.13.1 + - tqdm==4.64.1 + - typing-extensions==4.3.0 + - urllib3==1.26.12 + - wcwidth==0.2.5 + - wheel==0.37.1 diff --git a/task_vectors/img/task_vectors.png b/task_vectors/img/task_vectors.png new file mode 100644 index 0000000..37c53ba Binary files /dev/null and b/task_vectors/img/task_vectors.png differ diff --git a/task_vectors/negation.py b/task_vectors/negation.py new file mode 100644 index 0000000..66c2d72 --- /dev/null +++ b/task_vectors/negation.py @@ -0,0 +1,72 @@ + +import argparse +import logging +import os +import torch +import argparse +import os +from peft import PeftConfig, PeftModel + + +# import numpy as np +# import torch +# import random +# import 
pandas as pd
+# import pdb
+from transformers import AutoModelForCausalLM
+
+def adapter_negation(model, adapter_config, scale):
+    state_dict = model.state_dict()
+    # Sanity check: report how many adapter parameters the filters below match.
+    merged_keys = [mk for mk in state_dict.keys() if ("lora_A" in mk)]
+    print("lora_A:", len(merged_keys))
+    merged_keys = [mk for mk in state_dict.keys() if ("lora" in mk)]
+    print("lora:", len(merged_keys))
+
+    if adapter_config == "lora_adapter":
+        # LoRA adds a low-rank update B @ A to each weight, so the task vector
+        # is B @ A. Negating it with coefficient s amounts to A -> -s * A,
+        # which turns the update into -s * (B @ A).
+        neg_dict = {k: -1 * scale * v for k, v in state_dict.items() if "lora_A" in k}
+    elif adapter_config == "ia3_adapter":
+        # (IA)^3 rescales activations as h -> l * h, so the task vector relative
+        # to the identity is (l - 1). Negating it with coefficient s gives
+        # l_neg = 1 - s * (l - 1) = (1 + s) - s * l.
+        neg_dict = {k: (torch.ones(v.shape) * (1 + scale) - scale * v) for k, v in state_dict.items() if "lora" in k}
+    else:
+        raise ValueError(f"adapter_config {adapter_config} not supported")
+
+    state_dict.update(neg_dict)
+    model.load_state_dict(state_dict)
+    return model
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--pretrained_model_path', type=str, required=True)
+    parser.add_argument('--finetuned_model_path', type=str, required=True)
+    parser.add_argument('--scaling_coef', type=float, required=True)
+    parser.add_argument('--model_save_path', type=str, required=True)
+    args = parser.parse_args()
+
+    pretrained_model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path)
+    finetuned_model = PeftModel.from_pretrained(pretrained_model, args.finetuned_model_path)
+    # The adapter type ("lora_adapter" or "ia3_adapter") is inferred from the
+    # suffix of the fine-tuned checkpoint's directory name.
+    adapter_config = args.finetuned_model_path.split("-")[-1]
+    print(adapter_config)
+    unbiased_model = adapter_negation(finetuned_model, adapter_config, args.scaling_coef)
+    print("unbiased_model created")
+    unbiased_model = unbiased_model.merge_and_unload()
+    print("unbiased_model merged")
+    unbiased_model.save_pretrained(args.model_save_path)
+    print("unbiased_model saved")
diff --git a/task_vectors/src/args.py b/task_vectors/src/args.py
new file mode 100644
index 0000000..539f21a
--- /dev/null
+++ b/task_vectors/src/args.py
@@ -0,0 +1,106 @@
+import os
+import argparse
+
+import torch
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--data-location",
+        type=str,
+        default=os.path.expanduser('~/data'),
+        help="The root directory for the datasets.",
+    )
+    parser.add_argument(
+        "--eval-datasets",
+        default=None,
+        type=lambda x: x.split(","),
+        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT."
+    )
+    parser.add_argument(
+        "--train-dataset",
+        default=None,
+        type=lambda x: x.split(","),
+        help="Which dataset(s) to patch on.",
+    )
+    parser.add_argument(
+        "--exp_name",
+        type=str,
+        default=None,
+        help="Name of the experiment, for organization purposes only."
+    )
+    parser.add_argument(
+        "--results-db",
+        type=str,
+        default=None,
+        help="Where to store the results; if None, results are not stored.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+        help="The type of model (e.g. RN50, ViT-B-32).",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=128,
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=0.001,
+        help="Learning rate."
+ ) + parser.add_argument( + "--wd", + type=float, + default=0.1, + help="Weight decay" + ) + parser.add_argument( + "--ls", + type=float, + default=0.0, + help="Label smoothing." + ) + parser.add_argument( + "--warmup_length", + type=int, + default=500, + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + ) + parser.add_argument( + "--load", + type=lambda x: x.split(","), + default=None, + help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.", + ) + parser.add_argument( + "--save", + type=str, + default=None, + help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Directory for caching features and encoder", + ) + parser.add_argument( + "--openclip-cachedir", + type=str, + default='/gscratch/efml/gamaga/.cache/open_clip', + help='Directory for caching models from OpenCLIP' + ) + parsed_args = parser.parse_args() + parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu" + + if parsed_args.load is not None and len(parsed_args.load) == 1: + parsed_args.load = parsed_args.load[0] + return parsed_args diff --git a/task_vectors/src/datasets/cars.py b/task_vectors/src/datasets/cars.py new file mode 100644 index 0000000..3e10ab2 --- /dev/null +++ b/task_vectors/src/datasets/cars.py @@ -0,0 +1,155 @@ +import os +import torch +import torchvision.datasets as datasets + + +import pathlib +from typing import Callable, Optional, Any, Tuple + +from PIL import Image + +from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg +from torchvision.datasets.vision import VisionDataset + + +class PytorchStanfordCars(VisionDataset): + """`Stanford Cars `_ Dataset + + The Cars dataset contains 16,185 images of 196 classes of cars. The data is + split into 8,144 training images and 8,041 testing images, where each class + has been split roughly in a 50-50 split + + .. note:: + + This class needs `scipy `_ to load target files from `.mat` format. + + Args: + root (string): Root directory of dataset + split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again.""" + + def __init__( + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + + try: + import scipy.io as sio + except ImportError: + raise RuntimeError("Scipy is not found. 
This dataset needs to have scipy installed: pip install scipy") + + super().__init__(root, transform=transform, target_transform=target_transform) + + self._split = verify_str_arg(split, "split", ("train", "test")) + self._base_folder = pathlib.Path(root) / "stanford_cars" + devkit = self._base_folder / "devkit" + + if self._split == "train": + self._annotations_mat_path = devkit / "cars_train_annos.mat" + self._images_base_path = self._base_folder / "cars_train" + else: + self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" + self._images_base_path = self._base_folder / "cars_test" + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + self._samples = [ + ( + str(self._images_base_path / annotation["fname"]), + annotation["class"] - 1, # Original target mapping starts from 1, hence -1 + ) + for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] + ] + + self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() + self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} + + def __len__(self) -> int: + return len(self._samples) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + """Returns pil_image and class_id for given index""" + image_path, target = self._samples[idx] + pil_image = Image.open(image_path).convert("RGB") + + if self.transform is not None: + pil_image = self.transform(pil_image) + if self.target_transform is not None: + target = self.target_transform(target) + return pil_image, target + + + def download(self) -> None: + if self._check_exists(): + return + + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", + download_root=str(self._base_folder), + md5="c3b158d763b6e2245038c8ad08e45376", + ) + if self._split == "train": + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", + download_root=str(self._base_folder), + md5="065e5b463ae28d29e77c1b4b166cfe61", + ) + else: + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", + download_root=str(self._base_folder), + md5="4ce7ebf6a94d07f1952d94dd34c4d501", + ) + download_url( + url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", + root=str(self._base_folder), + md5="b0a2b23655a3edd16d84508592a98d10", + ) + + def _check_exists(self) -> bool: + if not (self._base_folder / "devkit").is_dir(): + return False + + return self._annotations_mat_path.exists() and self._images_base_path.is_dir() + + +class Cars: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + # Data loading code + + self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=True) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=True) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in self.train_dataset.class_to_idx.items()) + self.classnames = [idx_to_class[i].replace( + '_', ' ') for i in range(len(idx_to_class))] diff --git a/task_vectors/src/datasets/cifar10.py b/task_vectors/src/datasets/cifar10.py new file mode 100644 index 
0000000..096913b --- /dev/null +++ b/task_vectors/src/datasets/cifar10.py @@ -0,0 +1,56 @@ +import os +import PIL +import torch +import numpy as np +import torchvision +from torchvision import transforms +from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 +from torchvision.datasets import VisionDataset + +cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] + +class CIFAR10: + def __init__(self, preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + + self.train_dataset = PyTorchCIFAR10( + root=location, download=True, train=True, transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + + self.test_dataset = PyTorchCIFAR10( + root=location, download=True, train=False, transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + self.classnames = self.test_dataset.classes + +def convert(x): + if isinstance(x, np.ndarray): + return torchvision.transforms.functional.to_pil_image(x) + return x + +class BasicVisionDataset(VisionDataset): + def __init__(self, images, targets, transform=None, target_transform=None): + if transform is not None: + transform.transforms.insert(0, convert) + super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform) + assert len(images) == len(targets) + + self.images = images + self.targets = targets + + def __getitem__(self, index): + return self.transform(self.images[index]), self.targets[index] + + def __len__(self): + return len(self.targets) diff --git a/task_vectors/src/datasets/cifar100.py b/task_vectors/src/datasets/cifar100.py new file mode 100644 index 0000000..c7b3bb4 --- /dev/null +++ b/task_vectors/src/datasets/cifar100.py @@ -0,0 +1,30 @@ +import os +import torch +from torchvision.datasets import CIFAR100 as PyTorchCIFAR100 + +class CIFAR100: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + self.train_dataset = PyTorchCIFAR100( + root=location, download=True, train=True, transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, batch_size=batch_size, num_workers=num_workers + ) + + self.test_dataset = PyTorchCIFAR100( + root=location, download=True, train=False, transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + self.classnames = self.test_dataset.classes + + diff --git a/task_vectors/src/datasets/common.py b/task_vectors/src/datasets/common.py new file mode 100644 index 0000000..7cb6cce --- /dev/null +++ b/task_vectors/src/datasets/common.py @@ -0,0 +1,139 @@ +import os +import torch +import json +import glob +import collections +import random + +import numpy as np + +from tqdm import tqdm + +import torchvision.datasets as datasets +from torch.utils.data import Dataset, DataLoader, Sampler + + +class SubsetSampler(Sampler): + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return (i for i in self.indices) + + def __len__(self): + return len(self.indices) + +class ImageFolderWithPaths(datasets.ImageFolder): + def __init__(self, path, transform, flip_label_prob=0.0): + super().__init__(path, transform) + self.flip_label_prob = 
flip_label_prob
+        if self.flip_label_prob > 0:
+            print(f'Flipping labels with probability {self.flip_label_prob}')
+            num_classes = len(self.classes)
+            for i in range(len(self.samples)):
+                if random.random() < self.flip_label_prob:
+                    new_label = random.randint(0, num_classes-1)
+                    self.samples[i] = (
+                        self.samples[i][0],
+                        new_label
+                    )
+
+    def __getitem__(self, index):
+        image, label = super(ImageFolderWithPaths, self).__getitem__(index)
+        return {
+            'images': image,
+            'labels': label,
+            'image_paths': self.samples[index][0]
+        }
+
+
+def maybe_dictionarize(batch):
+    if isinstance(batch, dict):
+        return batch
+
+    if len(batch) == 2:
+        batch = {'images': batch[0], 'labels': batch[1]}
+    elif len(batch) == 3:
+        batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
+    else:
+        raise ValueError(f'Unexpected number of elements: {len(batch)}')
+
+    return batch
+
+
+def get_features_helper(image_encoder, dataloader, device):
+    all_data = collections.defaultdict(list)
+
+    image_encoder = image_encoder.to(device)
+    # Wrap in DataParallel to extract features on all visible GPUs.
+    image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())])
+    image_encoder.eval()
+
+    with torch.no_grad():
+        for batch in tqdm(dataloader):
+            batch = maybe_dictionarize(batch)
+            features = image_encoder(batch['images'].to(device))
+
+            all_data['features'].append(features.cpu())
+
+            for key, val in batch.items():
+                if key == 'images':
+                    continue
+                if hasattr(val, 'cpu'):
+                    val = val.cpu()
+                    all_data[key].append(val)
+                else:
+                    all_data[key].extend(val)
+
+    for key, val in all_data.items():
+        if torch.is_tensor(val[0]):
+            all_data[key] = torch.cat(val).numpy()
+
+    return all_data
+
+
+def get_features(is_train, image_encoder, dataset, device):
+    split = 'train' if is_train else 'val'
+    dname = type(dataset).__name__
+    # Define these up front so the else-branch below cannot hit an undefined
+    # variable when no cache directory was set.
+    cache_dir = None
+    cached_files = []
+    if image_encoder.cache_dir is not None:
+        cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
+        cached_files = glob.glob(f'{cache_dir}/*')
+    if len(cached_files) > 0:
+        print(f'Getting features from {cache_dir}')
+        data = {}
+        for cached_file in cached_files:
+            name = os.path.splitext(os.path.basename(cached_file))[0]
+            data[name] = torch.load(cached_file)
+    else:
+        print(f'Did not find cached features at {cache_dir}.
Building from scratch.') + loader = dataset.train_loader if is_train else dataset.test_loader + data = get_features_helper(image_encoder, loader, device) + if image_encoder.cache_dir is None: + print('Not caching because no cache directory was passed.') + else: + os.makedirs(cache_dir, exist_ok=True) + print(f'Caching data at {cache_dir}') + for name, val in data.items(): + torch.save(val, f'{cache_dir}/{name}.pt') + return data + + +class FeatureDataset(Dataset): + def __init__(self, is_train, image_encoder, dataset, device): + self.data = get_features(is_train, image_encoder, dataset, device) + + def __len__(self): + return len(self.data['features']) + + def __getitem__(self, idx): + data = {k: v[idx] for k, v in self.data.items()} + data['features'] = torch.from_numpy(data['features']).float() + return data + + +def get_dataloader(dataset, is_train, args, image_encoder=None): + if image_encoder is not None: + feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device) + dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) + else: + dataloader = dataset.train_loader if is_train else dataset.test_loader + return dataloader \ No newline at end of file diff --git a/task_vectors/src/datasets/dtd.py b/task_vectors/src/datasets/dtd.py new file mode 100644 index 0000000..79fb3c3 --- /dev/null +++ b/task_vectors/src/datasets/dtd.py @@ -0,0 +1,34 @@ +import os +import torch +import torchvision.datasets as datasets + + +class DTD: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + # Data loading code + traindir = os.path.join(location, 'dtd', 'train') + valdir = os.path.join(location, 'dtd', 'val') + + self.train_dataset = datasets.ImageFolder( + traindir, transform=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in self.train_dataset.class_to_idx.items()) + self.classnames = [idx_to_class[i].replace( + '_', ' ') for i in range(len(idx_to_class))] \ No newline at end of file diff --git a/task_vectors/src/datasets/eurosat.py b/task_vectors/src/datasets/eurosat.py new file mode 100644 index 0000000..8ef20d9 --- /dev/null +++ b/task_vectors/src/datasets/eurosat.py @@ -0,0 +1,75 @@ +import os +import torch +import torchvision.datasets as datasets +import re + +def pretify_classname(classname): + l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname) + l = [i.lower() for i in l] + out = ' '.join(l) + if out.endswith('al'): + return out + ' area' + return out + +class EuroSATBase: + def __init__(self, + preprocess, + test_split, + location='~/datasets', + batch_size=32, + num_workers=16): + # Data loading code + traindir = os.path.join(location, 'EuroSAT_splits', 'train') + testdir = os.path.join(location, 'EuroSAT_splits', test_split) + + + self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + 
num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in self.train_dataset.class_to_idx.items()) + self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] + self.classnames = [pretify_classname(c) for c in self.classnames] + ours_to_open_ai = { + 'annual crop': 'annual crop land', + 'forest': 'forest', + 'herbaceous vegetation': 'brushland or shrubland', + 'highway': 'highway or road', + 'industrial area': 'industrial buildings or commercial buildings', + 'pasture': 'pasture land', + 'permanent crop': 'permanent crop land', + 'residential area': 'residential buildings or homes or apartments', + 'river': 'river', + 'sea lake': 'lake or sea', + } + for i in range(len(self.classnames)): + self.classnames[i] = ours_to_open_ai[self.classnames[i]] + + +class EuroSAT(EuroSATBase): + def __init__(self, + preprocess, + location='~/datasets', + batch_size=32, + num_workers=16): + super().__init__(preprocess, 'test', location, batch_size, num_workers) + + +class EuroSATVal(EuroSATBase): + def __init__(self, + preprocess, + location='~/datasets', + batch_size=32, + num_workers=16): + super().__init__(preprocess, 'val', location, batch_size, num_workers) diff --git a/task_vectors/src/datasets/gtsrb.py b/task_vectors/src/datasets/gtsrb.py new file mode 100644 index 0000000..b089e12 --- /dev/null +++ b/task_vectors/src/datasets/gtsrb.py @@ -0,0 +1,205 @@ +import csv +import os +import pathlib +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +import PIL +import torch +from torchvision.datasets.folder import make_dataset +from torchvision.datasets.utils import (download_and_extract_archive, + verify_str_arg) +from torchvision.datasets.vision import VisionDataset + +def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: + """Finds the class folders in a dataset. + + See :class:`DatasetFolder` for details. + """ + classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) + if not classes: + raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") + + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + +class PyTorchGTSRB(VisionDataset): + """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset. + + Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
+ """ + + def __init__( + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + + super().__init__(root, transform=transform, target_transform=target_transform) + + self._split = verify_str_arg(split, "split", ("train", "test")) + self._base_folder = pathlib.Path(root) / "gtsrb" + self._target_folder = ( + self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") + ) + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + if self._split == "train": + _, class_to_idx = find_classes(str(self._target_folder)) + samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx) + else: + with open(self._base_folder / "GT-final_test.csv") as csv_file: + samples = [ + (str(self._target_folder / row["Filename"]), int(row["ClassId"])) + for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) + ] + + self._samples = samples + self.transform = transform + self.target_transform = target_transform + + def __len__(self) -> int: + return len(self._samples) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + + path, target = self._samples[index] + sample = PIL.Image.open(path).convert("RGB") + + if self.transform is not None: + sample = self.transform(sample) + + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + + def _check_exists(self) -> bool: + return self._target_folder.is_dir() + + def download(self) -> None: + if self._check_exists(): + return + + base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" + + if self._split == "train": + download_and_extract_archive( + f"{base_url}GTSRB-Training_fixed.zip", + download_root=str(self._base_folder), + md5="513f3c79a4c5141765e10e952eaa2478", + ) + else: + download_and_extract_archive( + f"{base_url}GTSRB_Final_Test_Images.zip", + download_root=str(self._base_folder), + md5="c7e4e6327067d32654124b0fe9e82185", + ) + download_and_extract_archive( + f"{base_url}GTSRB_Final_Test_GT.zip", + download_root=str(self._base_folder), + md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", + ) + + +class GTSRB: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + # to fit with repo conventions for location + self.train_dataset = PyTorchGTSRB( + root=location, + download=True, + split='train', + transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = PyTorchGTSRB( + root=location, + download=True, + split='test', + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md + self.classnames = [ + 'red and white circle 20 kph speed limit', + 'red and white circle 30 kph speed limit', + 'red and white circle 50 kph speed limit', + 'red and white circle 60 kph speed limit', + 'red and white circle 70 kph speed limit', + 'red and white circle 80 kph speed limit', + 'end / de-restriction of 80 kph speed limit', + 'red and white circle 100 kph speed limit', + 'red and white circle 120 kph 
speed limit', + 'red and white circle red car and black car no passing', + 'red and white circle red truck and black car no passing', + 'red and white triangle road intersection warning', + 'white and yellow diamond priority road', + 'red and white upside down triangle yield right-of-way', + 'stop', + 'empty red and white circle', + 'red and white circle no truck entry', + 'red circle with white horizonal stripe no entry', + 'red and white triangle with exclamation mark warning', + 'red and white triangle with black left curve approaching warning', + 'red and white triangle with black right curve approaching warning', + 'red and white triangle with black double curve approaching warning', + 'red and white triangle rough / bumpy road warning', + 'red and white triangle car skidding / slipping warning', + 'red and white triangle with merging / narrow lanes warning', + 'red and white triangle with person digging / construction / road work warning', + 'red and white triangle with traffic light approaching warning', + 'red and white triangle with person walking warning', + 'red and white triangle with child and person walking warning', + 'red and white triangle with bicyle warning', + 'red and white triangle with snowflake / ice warning', + 'red and white triangle with deer warning', + 'white circle with gray strike bar no speed limit', + 'blue circle with white right turn arrow mandatory', + 'blue circle with white left turn arrow mandatory', + 'blue circle with white forward arrow mandatory', + 'blue circle with white forward or right turn arrow mandatory', + 'blue circle with white forward or left turn arrow mandatory', + 'blue circle with white keep right arrow mandatory', + 'blue circle with white keep left arrow mandatory', + 'blue circle with white arrows indicating a traffic circle', + 'white circle with gray strike bar indicating no passing for cars has ended', + 'white circle with gray strike bar indicating no passing for trucks has ended', + ] diff --git a/task_vectors/src/datasets/imagenet.py b/task_vectors/src/datasets/imagenet.py new file mode 100644 index 0000000..4fbdcc4 --- /dev/null +++ b/task_vectors/src/datasets/imagenet.py @@ -0,0 +1,253 @@ +import os +import torch + +from .common import ImageFolderWithPaths, SubsetSampler +import numpy as np + + +imagenet_classnames = [ + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", 
"yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", 
"mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid 
camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's 
slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" +] + +class ImageNet: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=32): + self.preprocess = preprocess + self.location = location + self.batch_size = batch_size + self.num_workers = num_workers + self.classnames = imagenet_classnames + + self.populate_train() + self.populate_test() + + def populate_train(self): + traindir = os.path.join(self.location, self.name(), 'train') + self.train_dataset = ImageFolderWithPaths( + traindir, + transform=self.preprocess) + sampler = self.get_train_sampler() + kwargs = {'shuffle' : True} if sampler is None else {} + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + sampler=sampler, + batch_size=self.batch_size, + num_workers=self.num_workers, + **kwargs, + ) + + def populate_test(self): + self.test_dataset = self.get_test_dataset() + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + sampler=self.get_test_sampler() + ) + + def get_test_path(self): + test_path = os.path.join(self.location, self.name(), 'val_in_folder') + if not os.path.exists(test_path): + test_path = os.path.join(self.location, self.name(), 'val') + return test_path + + def get_train_sampler(self): + return None + + def get_test_sampler(self): + return None + + def get_test_dataset(self): + return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) + + def name(self): + return 'imagenet' + +class ImageNetTrain(ImageNet): + + def get_test_dataset(self): + pass + +class ImageNetK(ImageNet): + + def get_train_sampler(self): + idxs = np.zeros(len(self.train_dataset.targets)) + target_array = np.array(self.train_dataset.targets) + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:self.k()] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype('int') + sampler = SubsetSampler(np.where(idxs)[0]) + return sampler \ No newline at end of file diff --git a/task_vectors/src/datasets/mnist.py b/task_vectors/src/datasets/mnist.py new file mode 100644 index 0000000..dd48193 --- /dev/null +++ b/task_vectors/src/datasets/mnist.py @@ -0,0 +1,41 @@ +import os +import torch +import torchvision.datasets as datasets + +class MNIST: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + + self.train_dataset = datasets.MNIST( + root=location, + download=True, + train=True, + transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = datasets.MNIST( + root=location, + download=True, + train=False, + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] \ No newline at end of file diff --git a/task_vectors/src/datasets/registry.py b/task_vectors/src/datasets/registry.py new file mode 100644 index 0000000..ba07412 --- /dev/null +++ b/task_vectors/src/datasets/registry.py @@ -0,0 +1,100 @@ +import sys +import inspect +import random +import torch +import copy + +from 
torch.utils.data.dataset import random_split
+
+from src.datasets.cars import Cars
+from src.datasets.cifar10 import CIFAR10
+from src.datasets.cifar100 import CIFAR100
+from src.datasets.dtd import DTD
+from src.datasets.eurosat import EuroSAT, EuroSATVal
+from src.datasets.gtsrb import GTSRB
+from src.datasets.imagenet import ImageNet
+from src.datasets.mnist import MNIST
+from src.datasets.resisc45 import RESISC45
+from src.datasets.stl10 import STL10
+from src.datasets.svhn import SVHN
+from src.datasets.sun397 import SUN397
+
+registry = {
+    name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
+}
+
+
+class GenericDataset(object):
+    def __init__(self):
+        self.train_dataset = None
+        self.train_loader = None
+        self.test_dataset = None
+        self.test_loader = None
+        self.classnames = None
+
+
+def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0):
+    assert val_fraction > 0. and val_fraction < 1.
+    total_size = len(dataset.train_dataset)
+    val_size = int(total_size * val_fraction)
+    if max_val_samples is not None:
+        val_size = min(val_size, max_val_samples)
+    train_size = total_size - val_size
+
+    assert val_size > 0
+    assert train_size > 0
+
+    lengths = [train_size, val_size]
+
+    trainset, valset = random_split(
+        dataset.train_dataset,
+        lengths,
+        generator=torch.Generator().manual_seed(seed)
+    )
+    if new_dataset_class_name == 'MNISTVal':
+        # Sanity check that the seeded split is reproducible across runs.
+        assert trainset.indices[0] == 36044
+
+    new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {})
+    new_dataset = new_dataset_class()
+
+    new_dataset.train_dataset = trainset
+    new_dataset.train_loader = torch.utils.data.DataLoader(
+        new_dataset.train_dataset,
+        shuffle=True,
+        batch_size=batch_size,
+        num_workers=num_workers,
+    )
+
+    new_dataset.test_dataset = valset
+    new_dataset.test_loader = torch.utils.data.DataLoader(
+        new_dataset.test_dataset,
+        batch_size=batch_size,
+        num_workers=num_workers
+    )
+
+    new_dataset.classnames = copy.copy(dataset.classnames)
+
+    return new_dataset
+
+
+def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000):
+    if dataset_name.endswith('Val'):
+        # Handle val splits
+        if dataset_name in registry:
+            dataset_class = registry[dataset_name]
+        else:
+            base_dataset_name = dataset_name.split('Val')[0]
+            base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers)
+            dataset = split_train_into_train_val(
+                base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples)
+            return dataset
+    else:
+        assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
+
+    dataset_class = registry[dataset_name]
+    dataset = dataset_class(
+        preprocess, location=location, batch_size=batch_size, num_workers=num_workers
+    )
+    return dataset
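A quick usage sketch of the registry above (hedged: `preprocess` stands for the CLIP image transform returned by open_clip, and the data location is a placeholder). Names ending in `Val` are carved out of the training split rather than loaded from disk:

```python
from src.datasets.registry import get_dataset

# 'MNIST' loads the standard train/test split; 'MNISTVal' re-splits the
# training set so that test_loader serves as a held-out validation set.
dataset = get_dataset('MNISTVal', preprocess, location='~/data', batch_size=128)
train_loader = dataset.train_loader  # ~90% of the original training set
val_loader = dataset.test_loader     # the held-out 10%, capped at max_val_samples=5000
```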
diff --git a/task_vectors/src/datasets/resisc45.py b/task_vectors/src/datasets/resisc45.py new file mode 100644 index 0000000..056122d --- /dev/null +++ b/task_vectors/src/datasets/resisc45.py @@ -0,0 +1,304 @@
+import abc
+import os
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageFolder
+from torchvision.datasets.folder import default_loader as pil_loader
+
+
+# modified from: https://github.com/microsoft/torchgeo
+class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
+    """Abstract base class for datasets lacking geospatial information.
+    This base class is designed for datasets with pre-defined image chips.
+    """
+
+    @abc.abstractmethod
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        """Return an index within the dataset.
+        Args:
+            index: index to return
+        Returns:
+            data and labels at that index
+        Raises:
+            IndexError: if index is out of range of the dataset
+        """
+
+    @abc.abstractmethod
+    def __len__(self) -> int:
+        """Return the length of the dataset.
+        Returns:
+            length of the dataset
+        """
+
+    def __str__(self) -> str:
+        """Return the informal string representation of the object.
+        Returns:
+            informal string representation
+        """
+        return f"""\
+{self.__class__.__name__} Dataset
+    type: VisionDataset
+    size: {len(self)}"""
+
+
+class VisionClassificationDataset(VisionDataset, ImageFolder):
+    """Abstract base class for classification datasets lacking geospatial information.
+    This base class is designed for datasets with pre-defined image chips which
+    are separated into separate folders per class.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+        loader: Optional[Callable[[str], Any]] = pil_loader,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+    ) -> None:
+        """Initialize a new VisionClassificationDataset instance.
+        Args:
+            root: root directory where dataset can be found
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            loader: a callable function which takes as input a path to an image and
+                returns a PIL Image or numpy array
+            is_valid_file: A function that takes the path of an Image file and checks if
+                the file is a valid file
+        """
+        # When transform & target_transform are None, ImageFolder.__getitem__(index)
+        # returns a PIL.Image and int for image and label, respectively
+        super().__init__(
+            root=root,
+            transform=None,
+            target_transform=None,
+            loader=loader,
+            is_valid_file=is_valid_file,
+        )
+
+        # Must be set after calling super().__init__()
+        self.transforms = transforms
+
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
+        """Return an index within the dataset.
+        Args:
+            index: index to return
+        Returns:
+            data and label at that index
+        """
+        image, label = self._load_image(index)
+
+        if self.transforms is not None:
+            return self.transforms(image), label
+
+        return image, label
+
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+        Returns:
+            length of the dataset
+        """
+        return len(self.imgs)
+
+    def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
+        """Load a single image and its class label.
+        Args:
+            index: index to return
+        Returns:
+            the image
+            the image class label
+        """
+        img, label = ImageFolder.__getitem__(self, index)
+        label = torch.tensor(label)
+        return img, label
+
+
+class RESISC45Dataset(VisionClassificationDataset):
+    """RESISC45 dataset.
+    The RESISC45 dataset is a dataset for remote sensing image scene classification.
+    Dataset features:
+    * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
+    * three spectral bands - RGB
+    * 45 scene classes, 700 images per class
+    * images extracted from Google Earth from over 100 countries
+    * image conditions with high variability (resolution, weather, illumination)
+    Dataset format:
+    * images are three-channel jpgs
+    Dataset classes:
+    0. airplane
+    1. airport
+    2. baseball_diamond
+    3. basketball_court
+    4. beach
+    5. bridge
+    6. chaparral
+    7. church
+    8. circular_farmland
+    9. cloud
+    10. commercial_area
+    11. dense_residential
+    12. desert
+    13. forest
+    14. freeway
+    15. golf_course
+    16. ground_track_field
+    17. harbor
+    18. industrial_area
+    19. intersection
+    20. island
+    21. lake
+    22. meadow
+    23. medium_residential
+    24. mobile_home_park
+    25. mountain
+    26. overpass
+    27. palace
+    28. parking_lot
+    29. railway
+    30. railway_station
+    31. rectangular_farmland
+    32. river
+    33. roundabout
+    34. runway
+    35. sea_ice
+    36. ship
+    37. snowberg
+    38. sparse_residential
+    39. stadium
+    40. storage_tank
+    41. tennis_court
+    42. terrace
+    43. thermal_power_station
+    44. wetland
+    This dataset uses the train/val/test splits defined in the "In-domain representation
+    learning for remote sensing" paper:
+    * https://arxiv.org/abs/1911.06721
+    If you use this dataset in your research, please cite the following paper:
+    * https://doi.org/10.1109/jproc.2017.2675998
+    """
+
+    # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
+    # md5 = "d824acb73957502b00efd559fc6cfbbb"
+    # filename = "NWPU-RESISC45.rar"
+    directory = "resisc45/NWPU-RESISC45"
+
+    splits = ["train", "val", "test"]
+    split_urls = {
+        "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt",  # noqa: E501
+        "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt",  # noqa: E501
+        "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt",  # noqa: E501
+    }
+    split_md5s = {
+        "train": "b5a4c05a37de15e4ca886696a85c403e",
+        "val": "a0770cee4c5ca20b8c32bbd61e114805",
+        "test": "3dda9e4988b47eb1de9f07993653eb08",
+    }
+    classes = [
+        "airplane",
+        "airport",
+        "baseball_diamond",
+        "basketball_court",
+        "beach",
+        "bridge",
+        "chaparral",
+        "church",
+        "circular_farmland",
+        "cloud",
+        "commercial_area",
+        "dense_residential",
+        "desert",
+        "forest",
+        "freeway",
+        "golf_course",
+        "ground_track_field",
+        "harbor",
+        "industrial_area",
+        "intersection",
+        "island",
+        "lake",
+        "meadow",
+        "medium_residential",
+        "mobile_home_park",
+        "mountain",
+        "overpass",
+        "palace",
+        "parking_lot",
+        "railway",
+        "railway_station",
+        "rectangular_farmland",
+        "river",
+        "roundabout",
+        "runway",
+        "sea_ice",
+        "ship",
+        "snowberg",
+        "sparse_residential",
+        "stadium",
+        "storage_tank",
+        "tennis_court",
+        "terrace",
+        "thermal_power_station",
+        "wetland",
+    ]
+
+    def __init__(
+        self,
+        root: str = "data",
+        split: str = "train",
+        transforms:
Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + ) -> None: + """Initialize a new RESISC45 dataset instance. + Args: + root: root directory where dataset can be found + split: one of "train", "val", or "test" + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + """ + assert split in self.splits + self.root = root + + valid_fns = set() + with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: + for fn in f: + valid_fns.add(fn.strip()) + is_in_split: Callable[[str], bool] = lambda x: os.path.basename( + x) in valid_fns + + super().__init__( + root=os.path.join(root, self.directory), + transforms=transforms, + is_valid_file=is_in_split, + ) + + + +class RESISC45: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + + self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + + # class names have _ so split on this for better zero-shot head + self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes] diff --git a/task_vectors/src/datasets/stl10.py b/task_vectors/src/datasets/stl10.py new file mode 100644 index 0000000..0c7237f --- /dev/null +++ b/task_vectors/src/datasets/stl10.py @@ -0,0 +1,41 @@ +import os +import torch +import torchvision.datasets as datasets + +class STL10: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + location = os.path.join(location, 'stl10') + self.train_dataset = datasets.STL10( + root=location, + download=True, + split='train', + transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = datasets.STL10( + root=location, + download=True, + split='test', + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + self.classnames = self.train_dataset.classes \ No newline at end of file diff --git a/task_vectors/src/datasets/sun397.py b/task_vectors/src/datasets/sun397.py new file mode 100644 index 0000000..684c648 --- /dev/null +++ b/task_vectors/src/datasets/sun397.py @@ -0,0 +1,32 @@ +import os +import torch +import torchvision.datasets as datasets + +class SUN397: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + # Data loading code + traindir = os.path.join(location, 'sun397', 'train') + valdir = os.path.join(location, 'sun397', 'val') + + + self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in 
self.train_dataset.class_to_idx.items()) + self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))] diff --git a/task_vectors/src/datasets/svhn.py b/task_vectors/src/datasets/svhn.py new file mode 100644 index 0000000..0e9b47c --- /dev/null +++ b/task_vectors/src/datasets/svhn.py @@ -0,0 +1,45 @@ +import os +import torch +from torchvision.datasets import SVHN as PyTorchSVHN +import numpy as np + + +class SVHN: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + # to fit with repo conventions for location + modified_location = os.path.join(location, 'svhn') + + self.train_dataset = PyTorchSVHN( + root=modified_location, + download=True, + split='train', + transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = PyTorchSVHN( + root=modified_location, + download=True, + split='test', + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] diff --git a/task_vectors/src/datasets/templates.py b/task_vectors/src/datasets/templates.py new file mode 100644 index 0000000..f53a3c8 --- /dev/null +++ b/task_vectors/src/datasets/templates.py @@ -0,0 +1,225 @@ +cars_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'i love my {c}!', + lambda c: f'a photo of my dirty {c}.', + lambda c: f'a photo of my clean {c}.', + lambda c: f'a photo of my new {c}.', + lambda c: f'a photo of my old {c}.', +] + +cifar10_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +cifar100_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +dtd_template = [ + lambda c: f'a photo of a {c} texture.', + lambda c: f'a photo of a {c} pattern.', + lambda c: f'a photo of a {c} thing.', + lambda c: 
f'a photo of a {c} object.', + lambda c: f'a photo of the {c} texture.', + lambda c: f'a photo of the {c} pattern.', + lambda c: f'a photo of the {c} thing.', + lambda c: f'a photo of the {c} object.', +] + +eurosat_template = [ + lambda c: f'a centered satellite photo of {c}.', + lambda c: f'a centered satellite photo of a {c}.', + lambda c: f'a centered satellite photo of the {c}.', +] + +food101_template = [ + lambda c: f'a photo of {c}, a type of food.', +] + +gtsrb_template = [ + lambda c: f'a zoomed in photo of a "{c}" traffic sign.', + lambda c: f'a centered photo of a "{c}" traffic sign.', + lambda c: f'a close up photo of a "{c}" traffic sign.', +] + +mnist_template = [ + lambda c: f'a photo of the number: "{c}".', +] + +imagenet_template = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', 
+ lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] + +resisc45_template = [ + lambda c: f'satellite imagery of {c}.', + lambda c: f'aerial imagery of {c}.', + lambda c: f'satellite photo of {c}.', + lambda c: f'aerial photo of {c}.', + lambda c: f'satellite view of {c}.', + lambda c: f'aerial view of {c}.', + lambda c: f'satellite imagery of a {c}.', + lambda c: f'aerial imagery of a {c}.', + lambda c: f'satellite photo of a {c}.', + lambda c: f'aerial photo of a {c}.', + lambda c: f'satellite view of a {c}.', + lambda c: f'aerial view of a {c}.', + lambda c: f'satellite imagery of the {c}.', + lambda c: f'aerial imagery of the {c}.', + lambda c: f'satellite photo of the {c}.', + lambda c: f'aerial photo of the {c}.', + lambda c: f'satellite view of the {c}.', + lambda c: f'aerial view of the {c}.', +] + +stl10_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a photo of the {c}.', +] + +sun397_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a photo of the {c}.', +] + +svhn_template = [ + lambda c: f'a photo of the number: "{c}".', +] + + +dataset_to_template = { + 'Cars': cars_template, + 'CIFAR10': cifar10_template, + 'CIFAR100': cifar100_template, + 'DTD': dtd_template, + 'EuroSAT': eurosat_template, + 'Food101': food101_template, + 'GTSRB': gtsrb_template, + 'MNIST': mnist_template, + 'ImageNet': imagenet_template, + 'RESISC45': resisc45_template, + 'STL10': stl10_template, + 'SUN397': sun397_template, + 'SVHN': svhn_template, +} + + +def get_templates(dataset_name): + if dataset_name.endswith('Val'): + return get_templates(dataset_name.replace('Val', '')) + assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}' + return dataset_to_template[dataset_name] \ No newline at end of file diff --git a/task_vectors/src/eval.py b/task_vectors/src/eval.py new file mode 100644 index 0000000..b215992 --- /dev/null +++ b/task_vectors/src/eval.py @@ -0,0 +1,80 @@ +import os +import json +import tqdm + +import torch +import numpy as np + +from src import utils +from src.datasets.common import get_dataloader, maybe_dictionarize +from src.heads import get_classification_head +from src.modeling import ImageClassifier + +from src.datasets.registry import get_dataset + + +def eval_single_dataset(image_encoder, dataset_name, args): + classification_head = get_classification_head(args, dataset_name) + model = ImageClassifier(image_encoder, classification_head) + + model.eval() + + dataset = get_dataset( + dataset_name, + model.val_preprocess, + location=args.data_location, + batch_size=args.batch_size + ) + dataloader = get_dataloader( + dataset, is_train=False, args=args, image_encoder=None) + device = args.device + + with torch.no_grad(): + top1, correct, n = 0., 0., 0. + for i, data in enumerate(tqdm.tqdm(dataloader)): + data = maybe_dictionarize(data) + x = data['images'].to(device) + y = data['labels'].to(device) + + logits = utils.get_logits(x, model) + + pred = logits.argmax(dim=1, keepdim=True).to(device) + + correct += pred.eq(y.view_as(pred)).sum().item() + + n += y.size(0) + + top1 = correct / n + + metrics = {'top1': top1} + print(f'Done evaluating on {dataset_name}. 
Accuracy: {100*top1:.2f}%')
+
+    return metrics
+
+def evaluate(image_encoder, args):
+    if args.eval_datasets is None:
+        return
+    info = vars(args)
+    for i, dataset_name in enumerate(args.eval_datasets):
+        print('Evaluating on', dataset_name)
+
+        results = eval_single_dataset(image_encoder, dataset_name, args)
+
+        if 'top1' in results:
+            print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}")
+        for key, val in results.items():
+            if 'worst' in key or 'f1' in key.lower() or 'pm0' in key:
+                print(f"{dataset_name} {key}: {val:.4f}")
+            info[dataset_name + ':' + key] = val
+
+    if args.results_db is not None:
+        dirname = os.path.dirname(args.results_db)
+        if dirname:
+            os.makedirs(dirname, exist_ok=True)
+        with open(args.results_db, 'a+') as f:
+            f.write(json.dumps(info) + '\n')
+        print(f'Results saved to {args.results_db}.')
+    else:
+        print('Results not saved (to do so, use --results_db to specify a path).')
+
+    return info
\ No newline at end of file
diff --git a/task_vectors/src/finetune.py b/task_vectors/src/finetune.py new file mode 100644 index 0000000..6085b28 --- /dev/null +++ b/task_vectors/src/finetune.py @@ -0,0 +1,149 @@
+import os
+import time
+
+import torch
+
+from src.args import parse_arguments
+from src.datasets.common import get_dataloader, maybe_dictionarize
+from src.datasets.registry import get_dataset
+from src.eval import evaluate
+from src.modeling import ImageEncoder, ImageClassifier, MultiHeadImageClassifier
+from src.utils import cosine_lr, LabelSmoothing
+from src.heads import get_classification_head
+
+
+import src.datasets as datasets
+
+
+def finetune(args):
+    train_dataset = args.train_dataset
+    assert train_dataset is not None, "Please provide a training dataset."
+    ckpdir = os.path.join(args.save, train_dataset)
+
+    # Skip if this run's checkpoints already exist. They are saved below as
+    # 'zeroshot.pt' and 'finetuned.pt', so check for those names.
+    zs_path = os.path.join(ckpdir, 'zeroshot.pt')
+    ft_path = os.path.join(ckpdir, 'finetuned.pt')
+    if os.path.exists(zs_path) and os.path.exists(ft_path):
+        print(f'Skipping fine-tuning because {ft_path} exists.')
+        return zs_path, ft_path
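+    # The run below writes two encoders under ckpdir: 'zeroshot.pt' before any
+    # training and 'finetuned.pt' after the last epoch; subtracting the two
+    # checkpoints (see src/task_vectors.py) gives the task vector for this dataset.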
+ if args.load is not None and args.load.endswith('pt'): + image_encoder = ImageEncoder.load(args.load) + else: + print('Building image encoder.') + image_encoder = ImageEncoder(args, keep_lang=False) + + classification_head = get_classification_head(args, train_dataset) + + model = ImageClassifier(image_encoder, classification_head) + + model.freeze_head() + + preprocess_fn = model.train_preprocess + print_every = 100 + + dataset = get_dataset( + train_dataset, + preprocess_fn, + location=args.data_location, + batch_size=args.batch_size + ) + num_batches = len(dataset.train_loader) + + devices = list(range(torch.cuda.device_count())) + print('Using devices', devices) + model = torch.nn.DataParallel(model, device_ids=devices) + + if args.ls > 0: + loss_fn = LabelSmoothing(args.ls) + else: + loss_fn = torch.nn.CrossEntropyLoss() + + params = [p for p in model.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd) + + scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches) + + # Saving zero-shot model + if args.save is not None: + os.makedirs(ckpdir, exist_ok=True) + model_path = os.path.join(ckpdir, f'zeroshot.pt') + model.module.image_encoder.save(model_path) + + for epoch in range(args.epochs): + model = model.cuda() + model.train() + data_loader = get_dataloader( + dataset, is_train=True, args=args, image_encoder=None) + + for i, batch in enumerate(data_loader): + start_time = time.time() + + step = i + epoch * num_batches + scheduler(step) + optimizer.zero_grad() + + batch = maybe_dictionarize(batch) + inputs = batch['images'].to('cuda:0') + labels = batch['labels'].to('cuda:0') + data_time = time.time() - start_time + + logits = model(inputs) + + loss = loss_fn(logits, labels) + + loss.backward() + + torch.nn.utils.clip_grad_norm_(params, 1.0) + + optimizer.step() + batch_time = time.time() - start_time + + if step % print_every == 0: + percent_complete = 100 * i / len(data_loader) + print( + f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t" + f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True + ) + + # Evaluate + image_encoder = model.module.image_encoder + evaluate(image_encoder, args) + + if args.save is not None: + zs_path = os.path.join(ckpdir, 'zeroshot.pt') + ft_path = os.path.join(ckpdir, 'finetuned.pt') + image_encoder.save(ft_path) + return zs_path, ft_path + + +if __name__ == '__main__': + data_location = '' + models = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14'] + datasets = ['Cars', 'DTD', 'EuroSAT', 'GTSRB', 'MNIST', 'RESISC45', 'SUN397', 'SVHN'] + epochs = { + 'Cars': 35, + 'DTD': 76, + 'EuroSAT': 12, + 'GTSRB': 11, + 'MNIST': 5, + 'RESISC45': 15, + 'SUN397': 14, + 'SVHN': 4, + 'ImageNet': 4 + } + + for model in models: + for dataset in datasets: + print('='*100) + print(f'Finetuning {model} on {dataset}') + print('='*100) + args = parse_arguments() + args.lr = 1e-5 + args.epochs = epochs[dataset] + args.data_location = data_location + args.train_dataset = dataset + 'Val' + args.batch_size = 128 + args.model = model + args.save = f'checkpoints/{model}' + finetune(args) diff --git a/task_vectors/src/heads.py b/task_vectors/src/heads.py new file mode 100644 index 0000000..1287b8b --- /dev/null +++ b/task_vectors/src/heads.py @@ -0,0 +1,66 @@ +import os +import torch +from tqdm import tqdm + +import open_clip + +from src.datasets.templates import get_templates +from src.datasets.registry import get_dataset + +from 
src.modeling import ClassificationHead, ImageEncoder
+
+
+def build_classification_head(model, dataset_name, template, data_location, device):
+    logit_scale = model.logit_scale
+    dataset = get_dataset(
+        dataset_name,
+        None,
+        location=data_location
+    )
+    model.eval()
+    model.to(device)
+
+    print('Building classification head.')
+    with torch.no_grad():
+        zeroshot_weights = []
+        for classname in tqdm(dataset.classnames):
+            texts = []
+            for t in template:
+                texts.append(t(classname))
+            texts = open_clip.tokenize(texts).to(device)  # tokenize
+            embeddings = model.encode_text(texts)  # embed with text encoder
+            embeddings /= embeddings.norm(dim=-1, keepdim=True)
+
+            embeddings = embeddings.mean(dim=0, keepdim=True)
+            embeddings /= embeddings.norm()
+
+            zeroshot_weights.append(embeddings)
+
+        zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
+        zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
+
+        zeroshot_weights *= logit_scale.exp()
+
+        zeroshot_weights = zeroshot_weights.squeeze().float()
+        zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
+
+    classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
+
+    return classification_head
+
+
+def get_classification_head(args, dataset):
+    filename = os.path.join(args.save, f'head_{dataset}.pt')
+    if os.path.exists(filename):
+        print(f'Classification head for {args.model} on {dataset} exists at {filename}')
+        return ClassificationHead.load(filename)
+    print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.')
+    model = ImageEncoder(args, keep_lang=True).model
+    template = get_templates(dataset)
+    classification_head = build_classification_head(model, dataset, template, args.data_location, args.device)
+    os.makedirs(args.save, exist_ok=True)
+    classification_head.save(filename)
+    return classification_head
+
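Condensed, the recipe in `build_classification_head`: embed every prompt template for a class, average, and re-normalize. A minimal sketch, assuming an open_clip checkpoint and placeholder class names:

```python
import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
classnames = ['forest', 'freeway', 'harbor']  # placeholder classes
template = [lambda c: f'satellite imagery of {c}.',
            lambda c: f'aerial imagery of {c}.']

with torch.no_grad():
    rows = []
    for classname in classnames:
        texts = open_clip.tokenize([t(classname) for t in template])
        emb = model.encode_text(texts)              # [n_templates, dim]
        emb = emb / emb.norm(dim=-1, keepdim=True)  # normalize each template embedding
        emb = emb.mean(dim=0)                       # average over templates
        rows.append(emb / emb.norm())               # re-normalize the mean
    weights = torch.stack(rows) * model.logit_scale.exp()  # [n_classes, dim]
```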
diff --git a/task_vectors/src/modeling.py b/task_vectors/src/modeling.py new file mode 100644 index 0000000..2159e40 --- /dev/null +++ b/task_vectors/src/modeling.py @@ -0,0 +1,142 @@
+import torch
+
+import open_clip
+
+from src import utils
+
+
+class ImageEncoder(torch.nn.Module):
+    def __init__(self, args, keep_lang=False):
+        super().__init__()
+
+        print(f'Loading {args.model} pre-trained weights.')
+        if '__pretrained__' in args.model:
+            name, pretrained = args.model.split('__pretrained__')
+        else:
+            name = args.model
+            pretrained = 'openai'
+        self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms(
+            name, pretrained=pretrained, cache_dir=args.openclip_cachedir)
+
+        self.cache_dir = args.cache_dir
+
+        # Drop the text tower unless the language side is needed
+        # (e.g. for building zero-shot classification heads).
+        if not keep_lang and hasattr(self.model, 'transformer'):
+            delattr(self.model, 'transformer')
+
+    def forward(self, images):
+        assert self.model is not None
+        return self.model.encode_image(images)
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f'Saving image encoder to {filename}')
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        # save() pickles the whole module, so loading only needs a path; this
+        # matches the call site in finetune.py (ImageEncoder.load(args.load)).
+        print(f'Loading image encoder from {filename}')
+        return utils.torch_load(filename)
+
+
+class ClassificationHead(torch.nn.Linear):
+    def __init__(self, normalize, weights, biases=None):
+        output_size, input_size = weights.shape
+        super().__init__(input_size, output_size)
+        self.normalize = normalize
+        if weights is not None:
+            self.weight = torch.nn.Parameter(weights.clone())
+        if biases is not None:
+            self.bias = torch.nn.Parameter(biases.clone())
+        else:
+            self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
+
+    def forward(self, inputs):
+        if self.normalize:
+            inputs = inputs / inputs.norm(dim=-1, keepdim=True)
+        return super().forward(inputs)
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f'Saving classification head to {filename}')
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f'Loading classification head from {filename}')
+        return utils.torch_load(filename)
+
+
+class ImageClassifier(torch.nn.Module):
+    def __init__(self, image_encoder, classification_head):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_head = classification_head
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        self.classification_head.weight.requires_grad_(False)
+        self.classification_head.bias.requires_grad_(False)
+
+    def forward(self, inputs):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_head(features)
+        return outputs
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f'Saving image classifier to {filename}')
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f'Loading image classifier from {filename}')
+        return utils.torch_load(filename)
+
+
+class MultiHeadImageClassifier(torch.nn.Module):
+    def __init__(self, image_encoder, classification_heads):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_heads = torch.nn.ModuleList(classification_heads)
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        for idx in range(len(self.classification_heads)):
+            self.classification_heads[idx].weight.requires_grad_(False)
+            self.classification_heads[idx].bias.requires_grad_(False)
+
+    def forward(self, inputs, head_idx):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_heads[head_idx](features)
+        return outputs
+
+    def __call__(self, inputs, head_idx):
+        return self.forward(inputs, head_idx)
+
+    def save(self, filename):
+        print(f'Saving image classifier to {filename}')
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f'Loading image classifier from {filename}')
+        return utils.torch_load(filename)
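The file that follows re-implements `TaskVector` on top of HuggingFace's `AutoModelForCausalLM`; its `apply_to` takes a `scaling_coef`, so the strength of an edit can be dialed up or down. A hedged usage sketch with placeholder paths:

```python
from src.task_vectors import TaskVector

# Build the vector from two HF checkpoints (paths or hub names are placeholders).
task_vector = TaskVector(
    pretrained_checkpoint='path/to/pretrained',
    finetuned_checkpoint='path/to/finetuned',
)

# Subtract the fine-tuning direction at 80% strength; smaller coefficients
# make gentler edits. apply_to returns a ready-to-use HF model.
edited = (-task_vector).apply_to('path/to/pretrained', scaling_coef=0.8)
edited.save_pretrained('path/to/edited')
```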
diff --git a/task_vectors/src/task_vectors.py b/task_vectors/src/task_vectors.py new file mode 100644 index 0000000..5b934d6 --- /dev/null +++ b/task_vectors/src/task_vectors.py @@ -0,0 +1,83 @@
+import torch
+import argparse
+from transformers import AutoModelForCausalLM
+from tqdm import tqdm
+
+
+class TaskVector():
+    def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
+        """Initializes the task vector from pretrained and finetuned checkpoints.
+
+        This can either be done by passing two checkpoints (one corresponding to the
+        pretrained model, and another to the finetuned model), or by directly passing in
+        the task vector state dict.
+        """
+        if vector is not None:
+            self.vector = vector
+        else:
+            assert pretrained_checkpoint is not None and finetuned_checkpoint is not None
+            with torch.no_grad():
+                pretrained_state_dict = AutoModelForCausalLM.from_pretrained(pretrained_checkpoint).state_dict()
+                finetuned_state_dict = AutoModelForCausalLM.from_pretrained(finetuned_checkpoint).state_dict()
+                self.vector = {}
+                for key in tqdm(pretrained_state_dict):
+                    # Integer tensors (e.g. index buffers) are not learned weights; skip them.
+                    if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]:
+                        continue
+                    self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key]
+
+    def __add__(self, other):
+        """Add two task vectors together."""
+        with torch.no_grad():
+            new_vector = {}
+            for key in self.vector:
+                if key not in other.vector:
+                    continue
+                new_vector[key] = self.vector[key] + other.vector[key]
+        return TaskVector(vector=new_vector)
+
+    def __radd__(self, other):
+        # Lets sum() work over a list of task vectors (sum starts from 0).
+        if other is None or isinstance(other, int):
+            return self
+        return self.__add__(other)
+
+    def __neg__(self):
+        """Negate a task vector."""
+        with torch.no_grad():
+            new_vector = {}
+            for key in self.vector:
+                new_vector[key] = - self.vector[key]
+        return TaskVector(vector=new_vector)
+
+    def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
+        """Apply a task vector to a pretrained model."""
+        with torch.no_grad():
+            pretrained_model = AutoModelForCausalLM.from_pretrained(pretrained_checkpoint)
+            new_state_dict = {}
+            pretrained_state_dict = pretrained_model.state_dict()
+            for key in tqdm(pretrained_state_dict):
+                if key not in self.vector:
+                    print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector')
+                    continue
+                new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key]
+        pretrained_model.load_state_dict(new_state_dict, strict=False)
+        return pretrained_model
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--pretrained_model_path', type=str, required=True)
+    parser.add_argument('--finetuned_model_path', type=str, required=True)
+    parser.add_argument('--scaling_coef', type=float, required=True)
+    parser.add_argument('--model_save_path', type=str, required=True)
+    args = parser.parse_args()
+
+    vector = TaskVector(args.pretrained_model_path, args.finetuned_model_path)
+    print("TaskVector created")
+    neg_tv = -vector
+    unbiased_model = neg_tv.apply_to(args.pretrained_model_path, args.scaling_coef)
+    print("unbiased_model created")
+    unbiased_model.save_pretrained(args.model_save_path)
+    print("unbiased_model saved")
\ No newline at end of file
diff --git a/task_vectors/src/utils.py b/task_vectors/src/utils.py new file mode 100644 index 0000000..6a875b5 --- /dev/null +++ b/task_vectors/src/utils.py @@ -0,0 +1,91 @@
+import os
+
+import torch
+import pickle
+
+import numpy as np
+
+
+def assign_learning_rate(param_group, new_lr):
+    param_group["lr"] = new_lr
+
+
+def _warmup_lr(base_lr, warmup_length, step):
+    return base_lr * (step + 1) / warmup_length
+
+
+def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+    if not isinstance(base_lrs, list):
+        base_lrs = [base_lrs for _ in optimizer.param_groups]
+    assert len(base_lrs) == len(optimizer.param_groups)
+    def
_lr_adjuster(step): + for param_group, base_lr in zip(optimizer.param_groups, base_lrs): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(param_group, lr) + return _lr_adjuster + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def torch_load_old(save_path, device=None): + with open(save_path, 'rb') as f: + classifier = pickle.load(f) + if device is not None: + classifier = classifier.to(device) + return classifier + + +def torch_save(model, save_path): + if os.path.dirname(save_path) != '': + os.makedirs(os.path.dirname(save_path), exist_ok=True) + torch.save(model.cpu(), save_path) + + +def torch_load(save_path, device=None): + model = torch.load(save_path) + if device is not None: + model = model.to(device) + return model + + + +def get_logits(inputs, classifier): + assert callable(classifier) + if hasattr(classifier, 'to'): + classifier = classifier.to(inputs.device) + return classifier(inputs) + + +def get_probs(inputs, classifier): + if hasattr(classifier, 'predict_proba'): + probs = classifier.predict_proba(inputs.detach().cpu().numpy()) + return torch.from_numpy(probs) + logits = get_logits(inputs, classifier) + return logits.softmax(dim=1) + + +class LabelSmoothing(torch.nn.Module): + def __init__(self, smoothing=0.0): + super(LabelSmoothing, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1) + + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() diff --git a/task_vectors/train_model.py b/task_vectors/train_model.py new file mode 100644 index 0000000..3185de7 --- /dev/null +++ b/task_vectors/train_model.py @@ -0,0 +1,102 @@ +import transformers +from datasets import load_dataset +import sys +sys.path.append('/h/ws_tyau/bias-lm-stream/bias_mitigation') +from get_dataset import CustomDataset +import os +import torch +import torch.nn as nn +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM +import argparse +from peft import get_peft_model, LoraConfig, IA3Config, TaskType + +parser = argparse.ArgumentParser() + +parser.add_argument('--model_name', type=str, required=True) +parser.add_argument('--output_path', type=str, required=True) +parser.add_argument('--pem', type=str, default=None) + +args = parser.parse_args() + +data = CustomDataset("stereo_dpo") +data = data.get_dataset() +model = AutoModelForCausalLM.from_pretrained( + args.model_name, + device_map='auto', +) + +if args.pem: + if args.pem == 'lora_adapter': + llama_peft_config = LoraConfig( + task_type="CAUSAL_LM", + r=8, + lora_dropout=0.01, + ) + elif args.pem == 'ia3_adapter': + llama_peft_config = IA3Config( + peft_type="IA3", + task_type="CAUSAL_LM", + ) + else: + raise ValueError("Invalid PEM type for Llama model.") + model = get_peft_model(model, llama_peft_config) + model.print_trainable_parameters() + + +tokenizer = AutoTokenizer.from_pretrained(args.model_name) +if 'Llama' in args.model_name: + model.config.pad_token_id = 
model.config.eos_token_id
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
+data = data.map(lambda samples: tokenizer(samples['text']), batched=True)
+
+output_dir = args.output_path + args.model_name + "-" + str(args.pem)
+
+if 'Llama' in args.model_name:
+    # Llama: smaller per-device batch with more accumulation (effective batch 16).
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=data,
+        eval_dataset=data,
+        args=transformers.TrainingArguments(
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=8,
+            warmup_steps=100,
+            num_train_epochs=10,
+            learning_rate=5e-4,
+            logging_steps=10,
+            save_strategy="epoch",
+            evaluation_strategy="epoch",
+            metric_for_best_model="loss",
+            output_dir=output_dir,
+            load_best_model_at_end=True,
+        ),
+        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
+    )
+elif 'opt' in args.model_name:
+    # OPT: larger per-device batch and a lower learning rate; same effective batch of 16.
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=data,
+        eval_dataset=data,
+        args=transformers.TrainingArguments(
+            per_device_train_batch_size=4,
+            gradient_accumulation_steps=4,
+            warmup_steps=100,
+            num_train_epochs=10,
+            learning_rate=2e-4,
+            logging_steps=10,
+            save_strategy="epoch",
+            evaluation_strategy="epoch",
+            metric_for_best_model="loss",
+            output_dir=output_dir,
+            load_best_model_at_end=True,
+        ),
+        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
+    )
+else:
+    raise ValueError("Invalid model name.")
+trainer.train()
+
+model.save_pretrained(output_dir)
+tokenizer.save_pretrained(output_dir)
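End to end, the two scripts above chain as follows. A hedged sketch: the model name and output paths are placeholders, and if a PEFT adapter was used, the adapter weights may need to be merged into the base model before the state dicts can be subtracted.

```bash
# 1) Fine-tune (here: OPT with a LoRA adapter) on the bias-mitigation data.
python train_model.py \
    --model_name facebook/opt-1.3b \
    --output_path checkpoints/ \
    --pem lora_adapter

# 2) Negate the resulting task vector and apply it to the pre-trained model.
python src/task_vectors.py \
    --pretrained_model_path facebook/opt-1.3b \
    --finetuned_model_path checkpoints/facebook/opt-1.3b-lora_adapter \
    --scaling_coef 1.0 \
    --model_save_path checkpoints/opt-1.3b-edited
```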