Skip to content

Commit

Permalink
Add task vector scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
xeon27 committed Oct 9, 2024
1 parent cc77fff commit 7cee731
Show file tree
Hide file tree
Showing 28 changed files with 2,919 additions and 0 deletions.
136 changes: 136 additions & 0 deletions task_vectors/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Editing Models with Task Arithmetic

This repository contains code for the ICLR 2023 paper [Editing Models with Task Arithmetic](https://arxiv.org/abs/2212.04089), by Gabriel Ilharco, Marco Tulio Ribeiro, Mitchell Wortsman, Suchin Gururangan, Ludwig Schmidt, Hannaneh Hajishirzi and Ali Farhadi.

### Abstract
*Changing how pre-trained models behave---e.g., improving their performance on a downstream task or mitigating biases learned during pre-training---is a common practice when developing machine learning systems. In this work, we propose a new paradigm for steering the behavior of neural networks, centered around task vectors. A task vector specifies a direction in the weight space of a pre-trained model, such that movement in that direction improves performance on the task. We build task vectors by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning on a task. We show that these task vectors can be modified and combined together through arithmetic operations such as negation and addition, and the behavior of the resulting model is steered accordingly. Negating a task vector decreases performance on the target task, with little change in model behavior on control tasks. Moreover, adding task vectors together can improve performance on multiple tasks at once. Finally, when tasks are linked by an analogy relationship of the form ``A is to B as C is to D", combining task vectors from three of the tasks can improve performance on the fourth, even when no data from the fourth task is used for training. Overall, our experiments with several models, modalities and tasks show that task arithmetic is a simple, efficient and effective way of editing models.*


### Summary figure

<p align="center">
<img src="img/task_vectors.png" alt="scatter" width="100%"/>
</p>

An illustration of task vectors and the arithmetic operations we study for editing models. (a) A task vector is obtained by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning. (b) Negating a task vector degrades performance on the task, without substantial changes in control tasks. (c) Adding task vectors together improves the performance of the pre-trained model on the tasks under consideration. (d) When tasks form an analogy relationship such as supervised and unsupervised learning on two different data sources, it is possible to improve performance on a supervised target task using only vectors from the remaining three combinations of objectives and datasets.

## Code

### Install dependencies

```bash
conda env create
conda activate task-vectors
```


### Add directory to PYTHONPATH:

```bash
cd task_vectors
export PYTHONPATH="$PYTHONPATH:$PWD"
```

### Using task vectors

The task vector logic can be found at [src/task_vectors.py](src/task_vectors.py).

To create a task vector, you will need a pre-trained checkpoint and a fine-tuned checkpoint:

```python
from task_vectors import TaskVector
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
```

Once created, task vectors can be modified and combined through arithmetic operations! For instance, to negate a task vector, simply use the ```-``` operator:

```python
# Negating a task vector
new_task_vector = -task_vector
```

To add task vectors, you can use the ```+``` operator, or ```sum```:

```python
# Adding two task vectors
new_task_vector = task_vector_A + task_vector_B
# Adding multiple task vectors
new_task_vector = sum(list_of_task_vectors)
```

Analogies can be done as simply as:

```python
# Task analogies
new_task_vector = task_vector_C + task_vector_B - task_vector_A
```

### Checkpoints

Checkpoints for CLIP ViT-B/32, ViT-B/16 and ViT-L/14 are available on he link below, including fine-tuned checkpoints on eight downstream tasks: Stanford Cars, DTD, EuroSAT, GTSRB, MNIST, RESISC45, SUN397 and SVHN.

[Download here](https://drive.google.com/drive/folders/1u_Tva6x0p6oxu5Eo0ZZsf-520Cc_3MKw?usp=share_link)

### Examples

Below is an example of negating a task vector from MNIST, then evaluating on MNIST and on ImageNet:

```python
import torch
from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

# Config
dataset = 'MNIST'
model = 'ViT-L-14'
args = parse_arguments()
args.data_location = '/path/to/data'
args.model = model
args.save = f'checkpoints/{model}'
pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt'
finetuned_checkpoint = f'checkpoints/{model}/{dataset}/finetuned.pt'


# Create the task vector
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
# Negate the task vector
neg_task_vector = -task_vector
# Apply the task vector
image_encoder = neg_task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.5)
# Evaluate
eval_single_dataset(image_encoder, dataset, args)
eval_single_dataset(image_encoder, 'ImageNet', args)
```

You can also find an example of adding task vectors together below, using the MNIST and RESISC45 datasets:


```python
import torch
from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

# Config
datasets = ['MNIST', 'RESISC45']
model = 'ViT-L-14'
args = parse_arguments()
args.data_location = '/path/to/data'
args.model = model
args.save = f'checkpoints/{model}'
pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt'

# Create the task vectors
task_vectors = [
TaskVector(pretrained_checkpoint, f'checkpoints/{model}/{dataset}/finetuned.pt')
for dataset in datasets
]
# Sum the task vectors
task_vector_sum = sum(task_vectors)
# Apply the resulting task vector
image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=0.8)
# Evaluate
for dataset in datasets:
eval_single_dataset(image_encoder, dataset, args)
```
49 changes: 49 additions & 0 deletions task_vectors/download_pretrained_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import transformers
from datasets import load_dataset
import sys
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, IA3Config, TaskType

import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--model_name', type=str, required=True)
parser.add_argument('--pem', type=str, default=None)

args = parser.parse_args()


model = AutoModelForCausalLM.from_pretrained(
args.model_name,
device_map='auto',
)

if 'Llama' in args.model_name and args.pem is not None:
if args.pem == 'lora_adapter':
llama_peft_config = LoraConfig(
task_type="CAUSAL_LM",
r=8,
lora_dropout=0.01,
)
elif args.pem == 'ia3_adapter':
llama_peft_config = IA3Config(
peft_type="IA3",
task_type="CAUSAL_LM",
)
else:
raise ValueError("Invalid PEM type for Llama model.")
model = get_peft_model(model, llama_peft_config)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained(args.model_name)
if 'Llama' in args.model_name:
model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

model.save_pretrained("./outputs/pretrained/" + args.model_name)
tokenizer.save_pretrained("./outputs/pretrained/" + args.model_name)
108 changes: 108 additions & 0 deletions task_vectors/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
name: task-vectors
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotlipy=0.7.0=py310h5764c6d_1004
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.9.24=ha878542_0
- cffi=1.15.1=py310h74dc2b5_0
- cryptography=37.0.2=py310h597c629_0
- cudatoolkit=11.6.0=hecad31d_10
- ffmpeg=4.3=hf484d3e_0
- freetype=2.10.4=h0708190_1
- giflib=5.2.1=h36c2ea0_2
- gmp=6.2.1=h58526e2_0
- gnutls=3.6.13=h85f3911_1
- intel-openmp=2021.4.0=h06a4308_3561
- jbig=2.1=h7f98852_2003
- jpeg=9e=h166bdaf_1
- lame=3.100=h7f98852_1001
- lcms2=2.12=hddcbb42_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=2.2.1=h9c3ff4c_0
- libdeflate=1.7=h7f98852_5
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.17=h166bdaf_0
- libpng=1.6.37=h21135ba_2
- libstdcxx-ng=11.2.0=h1234567_1
- libtiff=4.3.0=hf544144_1
- libuuid=1.0.3=h7f8727e_2
- libwebp=1.2.2=h3452ae3_0
- libwebp-base=1.2.2=h7f98852_1
- lz4-c=1.9.3=h9c3ff4c_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py310ha2c4b55_0
- mkl_fft=1.3.1=py310h2b4bcf5_1
- mkl_random=1.2.2=py310h00e6091_0
- ncurses=6.3=h5eee18b_3
- nettle=3.6=he412f7d_0
- numpy-base=1.23.1=py310hcba007f_0
- openh264=2.1.1=h780b84a_0
- openssl=1.1.1o=h166bdaf_0
- pip=22.2.2=py310h06a4308_0
- python=3.10.4=h12debd9_0
- python_abi=3.10=2_cp310
- pytorch=1.12.1=py3.10_cuda11.6_cudnn8.3.2_0
- pytorch-mutex=1.0=cuda
- readline=8.1.2=h7f8727e_1
- setuptools=63.4.1=py310h06a4308_0
- sqlite=3.39.3=h5082296_0
- tk=8.6.12=h1ccaba5_0
- typing_extensions=4.3.0=pyha770c72_0
- tzdata=2022c=h04d1e81_0
- xz=5.2.6=h5eee18b_0
- zlib=1.2.12=h5eee18b_3
- zstd=1.5.0=ha95c52a_0
- pip:
- certifi==2022.9.24
- charset-normalizer==2.1.1
- contourpy==1.0.5
- cvxpy==1.2.2
- cycler==0.11.0
- ecos==2.0.10
- filelock==3.8.0
- fonttools==4.37.4
- ftfy==6.1.1
- huggingface-hub==0.10.0
- idna==3.4
- kiwisolver==1.4.4
- matplotlib==3.6.0
- numpy==1.23.3
- open-clip-torch==2.0.2
- osqp==0.6.2.post5
- packaging==21.3
- pandas==1.5.0
- patsy==0.5.2
- pillow==9.2.0
- plotly==5.11.0
- pycparser==2.21
- pyopenssl==22.0.0
- pyparsing==3.0.9
- pysocks==1.7.1
- python-dateutil==2.8.2
- pytz==2022.4
- pyyaml==6.0
- qdldl==0.1.5.post2
- regex==2022.9.13
- requests==2.28.1
- scipy==1.9.1
- scs==3.2.2
- seaborn==0.12.1
- six==1.16.0
- statsmodels==0.13.2
- tenacity==8.1.0
- torch==1.12.1
- torchaudio==0.12.1+cu116
- torchvision==0.13.1
- tqdm==4.64.1
- typing-extensions==4.3.0
- urllib3==1.26.12
- wcwidth==0.2.5
- wheel==0.37.1
Binary file added task_vectors/img/task_vectors.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
72 changes: 72 additions & 0 deletions task_vectors/negation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@

import argparse
import logging
import os
import torch
import argparse
import os
from peft import PeftConfig, PeftModel


# import numpy as np
# import torch
# import random
# import pandas as pd
# import pdb
from transformers import AutoModelForCausalLM

def adapter_negation(
model,
adapter_config,
scale
):


state_dict = model.state_dict()
merged_keys = [mk for mk in state_dict.keys() if ("lora_A" in mk)]
print("lora_A:",len(merged_keys))
merged_keys = [mk for mk in state_dict.keys() if ("lora" in mk)]
print("lora:",len(merged_keys))
# breakpoint()
if adapter_config=="lora_adapter":
# neg_dict = {k:-v for k,v in state_dict.items() if "lora_A" in k}
neg_dict = {k:-1*scale*v for k,v in state_dict.items() if "lora_A" in k}

# ia3 (h+l*delta_h)-(h+delta_h)=(l-1)*delta_h h+delta_h-(l-1)*delta_h=h+(2-l)*delta_h
elif adapter_config=="ia3_adapter":
# neg_dict = {k:(torch.ones(v.shape)*2-v) for k,v in state_dict.items() if "lora" in k}
neg_dict = {k:(torch.ones(v.shape)*(1+scale)-scale*v) for k,v in state_dict.items() if "lora" in k}
else:
raise ValueError(f"adapter_config {adapter_config} not supported")

state_dict.update(neg_dict)
model.load_state_dict(state_dict)
# model.set_active_adapters(["civil_comments"])
# model.save_all_adapters(save_path)

return model



if __name__ == "__main__":
# Create the parser
parser = argparse.ArgumentParser()

parser.add_argument('--pretrained_model_path', type=str, required=True)
parser.add_argument('--finetuned_model_path', type=str, required=True)
parser.add_argument('--scaling_coef', type=float, required=True)
parser.add_argument('--model_save_path', type = str, required=True)
args = parser.parse_args()

pretrained_model= AutoModelForCausalLM.from_pretrained(args.pretrained_model_path)
finetuned_model= PeftModel.from_pretrained(pretrained_model, args.finetuned_model_path)
adapter_config=args.finetuned_model_path.split("-")[-1]#[:-1]
print(args.finetuned_model_path.split("-")[-1][:-1])
unbiased_model = adapter_negation(finetuned_model,adapter_config,args.scaling_coef)
print("unbiased_model created")
unbiased_model = unbiased_model.merge_and_unload()
print("unbiased_model merged")
unbiased_model.save_pretrained(args.model_save_path)
print("unbiased_model saved")


Loading

0 comments on commit 7cee731

Please sign in to comment.