generated from VectorInstitute/aieng-template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
2,919 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
# Editing Models with Task Arithmetic | ||
|
||
This repository contains code for the ICLR 2023 paper [Editing Models with Task Arithmetic](https://arxiv.org/abs/2212.04089), by Gabriel Ilharco, Marco Tulio Ribeiro, Mitchell Wortsman, Suchin Gururangan, Ludwig Schmidt, Hannaneh Hajishirzi and Ali Farhadi. | ||
|
||
### Abstract | ||
*Changing how pre-trained models behave---e.g., improving their performance on a downstream task or mitigating biases learned during pre-training---is a common practice when developing machine learning systems. In this work, we propose a new paradigm for steering the behavior of neural networks, centered around task vectors. A task vector specifies a direction in the weight space of a pre-trained model, such that movement in that direction improves performance on the task. We build task vectors by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning on a task. We show that these task vectors can be modified and combined together through arithmetic operations such as negation and addition, and the behavior of the resulting model is steered accordingly. Negating a task vector decreases performance on the target task, with little change in model behavior on control tasks. Moreover, adding task vectors together can improve performance on multiple tasks at once. Finally, when tasks are linked by an analogy relationship of the form ``A is to B as C is to D", combining task vectors from three of the tasks can improve performance on the fourth, even when no data from the fourth task is used for training. Overall, our experiments with several models, modalities and tasks show that task arithmetic is a simple, efficient and effective way of editing models.* | ||
|
||
|
||
### Summary figure | ||
|
||
<p align="center"> | ||
<img src="img/task_vectors.png" alt="scatter" width="100%"/> | ||
</p> | ||
|
||
An illustration of task vectors and the arithmetic operations we study for editing models. (a) A task vector is obtained by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning. (b) Negating a task vector degrades performance on the task, without substantial changes in control tasks. (c) Adding task vectors together improves the performance of the pre-trained model on the tasks under consideration. (d) When tasks form an analogy relationship such as supervised and unsupervised learning on two different data sources, it is possible to improve performance on a supervised target task using only vectors from the remaining three combinations of objectives and datasets. | ||
|
||
## Code | ||
|
||
### Install dependencies | ||
|
||
```bash | ||
conda env create | ||
conda activate task-vectors | ||
``` | ||
|
||
|
||
### Add directory to PYTHONPATH: | ||
|
||
```bash | ||
cd task_vectors | ||
export PYTHONPATH="$PYTHONPATH:$PWD" | ||
``` | ||
|
||
### Using task vectors | ||
|
||
The task vector logic can be found at [src/task_vectors.py](src/task_vectors.py). | ||
|
||
To create a task vector, you will need a pre-trained checkpoint and a fine-tuned checkpoint: | ||
|
||
```python | ||
from task_vectors import TaskVector | ||
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint) | ||
``` | ||
|
||
Once created, task vectors can be modified and combined through arithmetic operations! For instance, to negate a task vector, simply use the ```-``` operator: | ||
|
||
```python | ||
# Negating a task vector | ||
new_task_vector = -task_vector | ||
``` | ||
|
||
To add task vectors, you can use the ```+``` operator, or ```sum```: | ||
|
||
```python | ||
# Adding two task vectors | ||
new_task_vector = task_vector_A + task_vector_B | ||
# Adding multiple task vectors | ||
new_task_vector = sum(list_of_task_vectors) | ||
``` | ||
|
||
Analogies can be done as simply as: | ||
|
||
```python | ||
# Task analogies | ||
new_task_vector = task_vector_C + task_vector_B - task_vector_A | ||
``` | ||
|
||
### Checkpoints | ||
|
||
Checkpoints for CLIP ViT-B/32, ViT-B/16 and ViT-L/14 are available on he link below, including fine-tuned checkpoints on eight downstream tasks: Stanford Cars, DTD, EuroSAT, GTSRB, MNIST, RESISC45, SUN397 and SVHN. | ||
|
||
[Download here](https://drive.google.com/drive/folders/1u_Tva6x0p6oxu5Eo0ZZsf-520Cc_3MKw?usp=share_link) | ||
|
||
### Examples | ||
|
||
Below is an example of negating a task vector from MNIST, then evaluating on MNIST and on ImageNet: | ||
|
||
```python | ||
import torch | ||
from task_vectors import TaskVector | ||
from eval import eval_single_dataset | ||
from args import parse_arguments | ||
|
||
# Config | ||
dataset = 'MNIST' | ||
model = 'ViT-L-14' | ||
args = parse_arguments() | ||
args.data_location = '/path/to/data' | ||
args.model = model | ||
args.save = f'checkpoints/{model}' | ||
pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt' | ||
finetuned_checkpoint = f'checkpoints/{model}/{dataset}/finetuned.pt' | ||
|
||
|
||
# Create the task vector | ||
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint) | ||
# Negate the task vector | ||
neg_task_vector = -task_vector | ||
# Apply the task vector | ||
image_encoder = neg_task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.5) | ||
# Evaluate | ||
eval_single_dataset(image_encoder, dataset, args) | ||
eval_single_dataset(image_encoder, 'ImageNet', args) | ||
``` | ||
|
||
You can also find an example of adding task vectors together below, using the MNIST and RESISC45 datasets: | ||
|
||
|
||
```python | ||
import torch | ||
from task_vectors import TaskVector | ||
from eval import eval_single_dataset | ||
from args import parse_arguments | ||
|
||
# Config | ||
datasets = ['MNIST', 'RESISC45'] | ||
model = 'ViT-L-14' | ||
args = parse_arguments() | ||
args.data_location = '/path/to/data' | ||
args.model = model | ||
args.save = f'checkpoints/{model}' | ||
pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt' | ||
|
||
# Create the task vectors | ||
task_vectors = [ | ||
TaskVector(pretrained_checkpoint, f'checkpoints/{model}/{dataset}/finetuned.pt') | ||
for dataset in datasets | ||
] | ||
# Sum the task vectors | ||
task_vector_sum = sum(task_vectors) | ||
# Apply the resulting task vector | ||
image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=0.8) | ||
# Evaluate | ||
for dataset in datasets: | ||
eval_single_dataset(image_encoder, dataset, args) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import transformers | ||
from datasets import load_dataset | ||
import sys | ||
import os | ||
import torch | ||
import torch.nn as nn | ||
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM | ||
from peft import get_peft_model, LoraConfig, IA3Config, TaskType | ||
|
||
import argparse | ||
|
||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('--model_name', type=str, required=True) | ||
parser.add_argument('--pem', type=str, default=None) | ||
|
||
args = parser.parse_args() | ||
|
||
|
||
model = AutoModelForCausalLM.from_pretrained( | ||
args.model_name, | ||
device_map='auto', | ||
) | ||
|
||
if 'Llama' in args.model_name and args.pem is not None: | ||
if args.pem == 'lora_adapter': | ||
llama_peft_config = LoraConfig( | ||
task_type="CAUSAL_LM", | ||
r=8, | ||
lora_dropout=0.01, | ||
) | ||
elif args.pem == 'ia3_adapter': | ||
llama_peft_config = IA3Config( | ||
peft_type="IA3", | ||
task_type="CAUSAL_LM", | ||
) | ||
else: | ||
raise ValueError("Invalid PEM type for Llama model.") | ||
model = get_peft_model(model, llama_peft_config) | ||
model.print_trainable_parameters() | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(args.model_name) | ||
if 'Llama' in args.model_name: | ||
model.config.pad_token_id = model.config.eos_token_id | ||
tokenizer.pad_token = tokenizer.eos_token | ||
tokenizer.pad_token_id = tokenizer.eos_token_id | ||
|
||
model.save_pretrained("./outputs/pretrained/" + args.model_name) | ||
tokenizer.save_pretrained("./outputs/pretrained/" + args.model_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
name: task-vectors | ||
channels: | ||
- pytorch | ||
- conda-forge | ||
- defaults | ||
dependencies: | ||
- _libgcc_mutex=0.1=main | ||
- _openmp_mutex=5.1=1_gnu | ||
- blas=1.0=mkl | ||
- brotlipy=0.7.0=py310h5764c6d_1004 | ||
- bzip2=1.0.8=h7b6447c_0 | ||
- ca-certificates=2022.9.24=ha878542_0 | ||
- cffi=1.15.1=py310h74dc2b5_0 | ||
- cryptography=37.0.2=py310h597c629_0 | ||
- cudatoolkit=11.6.0=hecad31d_10 | ||
- ffmpeg=4.3=hf484d3e_0 | ||
- freetype=2.10.4=h0708190_1 | ||
- giflib=5.2.1=h36c2ea0_2 | ||
- gmp=6.2.1=h58526e2_0 | ||
- gnutls=3.6.13=h85f3911_1 | ||
- intel-openmp=2021.4.0=h06a4308_3561 | ||
- jbig=2.1=h7f98852_2003 | ||
- jpeg=9e=h166bdaf_1 | ||
- lame=3.100=h7f98852_1001 | ||
- lcms2=2.12=hddcbb42_0 | ||
- ld_impl_linux-64=2.38=h1181459_1 | ||
- lerc=2.2.1=h9c3ff4c_0 | ||
- libdeflate=1.7=h7f98852_5 | ||
- libffi=3.3=he6710b0_2 | ||
- libgcc-ng=11.2.0=h1234567_1 | ||
- libgomp=11.2.0=h1234567_1 | ||
- libiconv=1.17=h166bdaf_0 | ||
- libpng=1.6.37=h21135ba_2 | ||
- libstdcxx-ng=11.2.0=h1234567_1 | ||
- libtiff=4.3.0=hf544144_1 | ||
- libuuid=1.0.3=h7f8727e_2 | ||
- libwebp=1.2.2=h3452ae3_0 | ||
- libwebp-base=1.2.2=h7f98852_1 | ||
- lz4-c=1.9.3=h9c3ff4c_1 | ||
- mkl=2021.4.0=h06a4308_640 | ||
- mkl-service=2.4.0=py310ha2c4b55_0 | ||
- mkl_fft=1.3.1=py310h2b4bcf5_1 | ||
- mkl_random=1.2.2=py310h00e6091_0 | ||
- ncurses=6.3=h5eee18b_3 | ||
- nettle=3.6=he412f7d_0 | ||
- numpy-base=1.23.1=py310hcba007f_0 | ||
- openh264=2.1.1=h780b84a_0 | ||
- openssl=1.1.1o=h166bdaf_0 | ||
- pip=22.2.2=py310h06a4308_0 | ||
- python=3.10.4=h12debd9_0 | ||
- python_abi=3.10=2_cp310 | ||
- pytorch=1.12.1=py3.10_cuda11.6_cudnn8.3.2_0 | ||
- pytorch-mutex=1.0=cuda | ||
- readline=8.1.2=h7f8727e_1 | ||
- setuptools=63.4.1=py310h06a4308_0 | ||
- sqlite=3.39.3=h5082296_0 | ||
- tk=8.6.12=h1ccaba5_0 | ||
- typing_extensions=4.3.0=pyha770c72_0 | ||
- tzdata=2022c=h04d1e81_0 | ||
- xz=5.2.6=h5eee18b_0 | ||
- zlib=1.2.12=h5eee18b_3 | ||
- zstd=1.5.0=ha95c52a_0 | ||
- pip: | ||
- certifi==2022.9.24 | ||
- charset-normalizer==2.1.1 | ||
- contourpy==1.0.5 | ||
- cvxpy==1.2.2 | ||
- cycler==0.11.0 | ||
- ecos==2.0.10 | ||
- filelock==3.8.0 | ||
- fonttools==4.37.4 | ||
- ftfy==6.1.1 | ||
- huggingface-hub==0.10.0 | ||
- idna==3.4 | ||
- kiwisolver==1.4.4 | ||
- matplotlib==3.6.0 | ||
- numpy==1.23.3 | ||
- open-clip-torch==2.0.2 | ||
- osqp==0.6.2.post5 | ||
- packaging==21.3 | ||
- pandas==1.5.0 | ||
- patsy==0.5.2 | ||
- pillow==9.2.0 | ||
- plotly==5.11.0 | ||
- pycparser==2.21 | ||
- pyopenssl==22.0.0 | ||
- pyparsing==3.0.9 | ||
- pysocks==1.7.1 | ||
- python-dateutil==2.8.2 | ||
- pytz==2022.4 | ||
- pyyaml==6.0 | ||
- qdldl==0.1.5.post2 | ||
- regex==2022.9.13 | ||
- requests==2.28.1 | ||
- scipy==1.9.1 | ||
- scs==3.2.2 | ||
- seaborn==0.12.1 | ||
- six==1.16.0 | ||
- statsmodels==0.13.2 | ||
- tenacity==8.1.0 | ||
- torch==1.12.1 | ||
- torchaudio==0.12.1+cu116 | ||
- torchvision==0.13.1 | ||
- tqdm==4.64.1 | ||
- typing-extensions==4.3.0 | ||
- urllib3==1.26.12 | ||
- wcwidth==0.2.5 | ||
- wheel==0.37.1 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
|
||
import argparse | ||
import logging | ||
import os | ||
import torch | ||
import argparse | ||
import os | ||
from peft import PeftConfig, PeftModel | ||
|
||
|
||
# import numpy as np | ||
# import torch | ||
# import random | ||
# import pandas as pd | ||
# import pdb | ||
from transformers import AutoModelForCausalLM | ||
|
||
def adapter_negation( | ||
model, | ||
adapter_config, | ||
scale | ||
): | ||
|
||
|
||
state_dict = model.state_dict() | ||
merged_keys = [mk for mk in state_dict.keys() if ("lora_A" in mk)] | ||
print("lora_A:",len(merged_keys)) | ||
merged_keys = [mk for mk in state_dict.keys() if ("lora" in mk)] | ||
print("lora:",len(merged_keys)) | ||
# breakpoint() | ||
if adapter_config=="lora_adapter": | ||
# neg_dict = {k:-v for k,v in state_dict.items() if "lora_A" in k} | ||
neg_dict = {k:-1*scale*v for k,v in state_dict.items() if "lora_A" in k} | ||
|
||
# ia3 (h+l*delta_h)-(h+delta_h)=(l-1)*delta_h h+delta_h-(l-1)*delta_h=h+(2-l)*delta_h | ||
elif adapter_config=="ia3_adapter": | ||
# neg_dict = {k:(torch.ones(v.shape)*2-v) for k,v in state_dict.items() if "lora" in k} | ||
neg_dict = {k:(torch.ones(v.shape)*(1+scale)-scale*v) for k,v in state_dict.items() if "lora" in k} | ||
else: | ||
raise ValueError(f"adapter_config {adapter_config} not supported") | ||
|
||
state_dict.update(neg_dict) | ||
model.load_state_dict(state_dict) | ||
# model.set_active_adapters(["civil_comments"]) | ||
# model.save_all_adapters(save_path) | ||
|
||
return model | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
# Create the parser | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('--pretrained_model_path', type=str, required=True) | ||
parser.add_argument('--finetuned_model_path', type=str, required=True) | ||
parser.add_argument('--scaling_coef', type=float, required=True) | ||
parser.add_argument('--model_save_path', type = str, required=True) | ||
args = parser.parse_args() | ||
|
||
pretrained_model= AutoModelForCausalLM.from_pretrained(args.pretrained_model_path) | ||
finetuned_model= PeftModel.from_pretrained(pretrained_model, args.finetuned_model_path) | ||
adapter_config=args.finetuned_model_path.split("-")[-1]#[:-1] | ||
print(args.finetuned_model_path.split("-")[-1][:-1]) | ||
unbiased_model = adapter_negation(finetuned_model,adapter_config,args.scaling_coef) | ||
print("unbiased_model created") | ||
unbiased_model = unbiased_model.merge_and_unload() | ||
print("unbiased_model merged") | ||
unbiased_model.save_pretrained(args.model_save_path) | ||
print("unbiased_model saved") | ||
|
||
|
Oops, something went wrong.