Commit
Merge pull request #11 from normal-computing/laplace-lora
Laplace lora
phoebeklett authored Feb 14, 2024
2 parents 0cf5dd7 + 1e01e0c commit 5de8c2f
Showing 8 changed files with 506 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -102,4 +102,7 @@ ENV/
/site

# mypy
.mypy_cache/
.mypy_cache/

# Experiment runs
experiments/runs/
164 changes: 164 additions & 0 deletions experiments/eval_lora_transformer.py
@@ -0,0 +1,164 @@
import torch
import wandb
from tqdm import tqdm
import pickle
from omegaconf import OmegaConf
import os
from ml_collections.config_dict import ConfigDict
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from experiments.laplace_lora import TransformerModule
from experiments.utils.utils import load_config


def evaluate(model, tuned, dataset):
results = []
avg_nlls_base = ()
avg_ppls_base = ()
avg_nlls_tuned = ()
avg_ppls_tuned = ()
for idx, sample in tqdm(enumerate(dataset)):
input_ids = sample["input_ids"].unsqueeze(0)
max_input_length = model.config.max_position_embeddings

sample_nlls_base = []
sample_nlls_tuned = []
prev_end_loc = 0
seq_len = input_ids.size(1)
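        # Sliding-window evaluation: each window is up to max_position_embeddings tokens wide and advances with a stride of 512.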
for begin_loc in range(0, seq_len, 512):
end_loc = min(begin_loc + max_input_length, seq_len)
subseq = input_ids[:, begin_loc:end_loc]
targets = subseq.clone()
trg_len = end_loc - prev_end_loc
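            # Mask tokens already scored in the previous window (-100 is ignored by the loss) so each token's NLL is counted once.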
targets[:, :-trg_len] = -100

with torch.no_grad():
output_base = model(
input_ids=subseq.to(model.device),
                    labels=targets.to(model.device),
)
sample_nlls_base.append(output_base.loss)

output_tuned = tuned.model(
input_ids=subseq.to(tuned.model.device),
                    labels=targets.to(tuned.model.device),
)
sample_nlls_tuned.append(output_tuned.loss)

prev_end_loc = end_loc
if end_loc == seq_len:
break

sample_nlls_base = torch.tensor(sample_nlls_base)
sample_ppls_base = torch.exp(sample_nlls_base)

sample_avg_nll_base = torch.mean(sample_nlls_base)
sample_avg_ppl_base = torch.mean(sample_ppls_base)
wandb.log({"sample_avg_nll_base": sample_avg_nll_base})
wandb.log({"sample_avg_ppl_base": sample_avg_ppl_base})

sample_nlls_tuned = torch.tensor(sample_nlls_tuned)
sample_ppls_tuned = torch.exp(sample_nlls_tuned)

sample_avg_nll_tuned = torch.mean(sample_nlls_tuned)
sample_avg_ppl_tuned = torch.mean(sample_ppls_tuned)
wandb.log({"sample_avg_nll_tuned": sample_avg_nll_tuned})
wandb.log({"sample_avg_ppl_tuned": sample_avg_ppl_tuned})

results += [
{
"idx": idx,
"input_ids": sample["input_ids"],
"nlls_base": sample_nlls_base,
"nlls_tuned": sample_nlls_tuned,
"ppls_base": sample_ppls_base,
"ppls_tuned": sample_ppls_tuned,
"avg_nll_base": sample_avg_nll_base,
"avg_ppl_base": sample_avg_ppl_base,
"avg_nll_tuned": sample_avg_nll_tuned,
"avg_ppl_tuned": sample_avg_ppl_tuned,
}
]

avg_nlls_base += (sample_avg_nll_base,)
avg_ppls_base += (sample_avg_ppl_base,)

avg_nlls_tuned += (sample_avg_nll_tuned,)
avg_ppls_tuned += (sample_avg_ppl_tuned,)

avg_nll_base = torch.mean(torch.tensor(avg_nlls_base))
avg_ppl_base = torch.mean(torch.tensor(avg_ppls_base))

avg_nll_tuned = torch.mean(torch.tensor(avg_nlls_tuned))
avg_ppl_tuned = torch.mean(torch.tensor(avg_ppls_tuned))

wandb.log({"Avg NLL, Base Model": avg_nll_base})
wandb.log({"Avg PPL, Base Model": avg_ppl_base})

wandb.log({"Avg NLL, Tuned Model": avg_nll_tuned})
wandb.log({"Avg PPL, Tuned Model": avg_ppl_tuned})

return results


DATETIME = ""
EXPERIMENT_LOG_DIR = f"./experiments/runs/laplace_lora/{DATETIME}_laplace_lora"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LORA_WEIGHTS = EXPERIMENT_LOG_DIR + "/checkpoints/last.ckpt"
CONFIG = EXPERIMENT_LOG_DIR + "/laplace_lora.yaml"
config = ConfigDict(load_config(CONFIG))

model_tuned = TransformerModule.load_from_checkpoint(
LORA_WEIGHTS, config=config["model_config"]
).to(device)
model = AutoModelForCausalLM.from_pretrained(
config.model_config.pretrained_model_name_or_path
).to(device)
print("Weights loaded successfully!")

wandb.init(
project=config["experiment_name"],
dir=config.get("logs_dir", "logs"),
)
config.wandb_id_eval = wandb.run.id
config.wandb_name_eval = wandb.run.name

OmegaConf.save(
config=config.to_dict(),
f=EXPERIMENT_LOG_DIR + f"/{config['experiment_name']}.yaml",
)

dataset = load_dataset(config.dataset_name)
tokenizer = AutoTokenizer.from_pretrained(
config.tokenizer_pretrained_model_name_or_path
)
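    # Many causal-LM tokenizers ship without a pad token; reuse EOS so fixed-length padding works.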
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
max_length=config.max_length,
truncation=True,
)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns([config.inputs_key])
tokenized_datasets.set_format("torch")

train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["test"]

if config.small:
train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100))
eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100))

results = evaluate(model, model_tuned, eval_dataset)

result_file = os.path.join(EXPERIMENT_LOG_DIR, "results-eval.pkl")
with open(result_file, "wb") as f:
pickle.dump(results, f)
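
A minimal sketch of loading results-eval.pkl back for offline analysis, assuming the same EXPERIMENT_LOG_DIR layout as above with the run timestamp filled in:

import os
import pickle

import torch

EXPERIMENT_LOG_DIR = "./experiments/runs/laplace_lora/<DATETIME>_laplace_lora"  # fill in the run timestamp

with open(os.path.join(EXPERIMENT_LOG_DIR, "results-eval.pkl"), "rb") as f:
    results = pickle.load(f)

# Each entry stores per-sample NLL/PPL tensors plus their averages (see evaluate() above).
avg_nll_base = torch.stack([r["avg_nll_base"] for r in results]).mean()
avg_nll_tuned = torch.stack([r["avg_nll_tuned"] for r in results]).mean()
print(f"avg NLL base: {avg_nll_base:.4f} | avg NLL tuned: {avg_nll_tuned:.4f}")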
1 change: 1 addition & 0 deletions experiments/laplace_lora/__init__.py
@@ -0,0 +1 @@
from experiments.laplace_lora.lora_transformer import TransformerModule
109 changes: 109 additions & 0 deletions experiments/laplace_lora/lora_transformer.py
@@ -0,0 +1,109 @@
from itertools import groupby
from functools import partial
from optree import tree_map, tree_reduce
import lightning as L
import torch
from torch.optim import AdamW
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
from ml_collections.config_dict import FrozenConfigDict

import uqlib
from uqlib import model_to_function


class TransformerModule(L.LightningModule):
def __init__(self, config: FrozenConfigDict):
super().__init__()
self.automatic_optimization = False

self.pretrained_model_name_or_path = config.pretrained_model_name_or_path
self.prior_sd = config.prior_sd
self.lr = config.lr

self.target_modules = config.lora_config.target_modules
self.r = config.lora_config.r
self.alpha = config.lora_config.alpha
self.dropout = config.lora_config.dropout

model = AutoModelForCausalLM.from_pretrained(self.pretrained_model_name_or_path)
# only adapt W_q, W_v, W_o
# regex may not work for all models

WEIGHTS_TO_LORA = ["q_proj", "v_proj", "o_proj"]

modules = list(model.model.layers.named_parameters())
# Get layer index, name for layers to adapt
module_names_with_layer = [
            (name.split(".")[0], f"layer.{name.removesuffix('.weight')}")
for name, param in modules
if any(
sub in name
for sub in [
"self_attn.{sub}".format(sub=sub) for sub in WEIGHTS_TO_LORA
]
)
]

# Subset of layers to adapt
if self.target_modules == "last_layer":
modules = [
[layer for name, layer in list(group)]
for _, group in groupby(module_names_with_layer, key=lambda x: x[0])
][-1]
else:
modules = [name for layer, name in module_names_with_layer]

peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=modules,
r=self.r,
lora_alpha=self.alpha,
lora_dropout=self.dropout,
)

self.model = get_peft_model(model, peft_config)
self.model.print_trainable_parameters()
self.model_func = model_to_function(self.model)

@staticmethod
def univariate_normal_log_prob(x, mean, sd):
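        # Unnormalized Gaussian log-density; the constant term is dropped since it does not affect gradients.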
return -0.5 * ((x - mean) / sd) ** 2

    def normal_log_prior(self, p) -> torch.Tensor:
per_group_vals = tree_map(
lambda p: self.univariate_normal_log_prob(p, 0, self.prior_sd).sum(), p
)
return tree_reduce(torch.add, per_group_vals)

    def param_to_log_posterior(self, p, batch, num_data) -> tuple:
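        # output.loss is the mean token negative log-likelihood, so the log prior is
        # divided by num_data to keep likelihood and prior on the same per-datum scale.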
output = self.model_func(p, labels=batch["input_ids"], **batch)
return (-output.loss) + self.normal_log_prior(p) / num_data, output

def on_train_start(self) -> None:
param_to_log_posterior = partial(
self.param_to_log_posterior,
num_data=len(self.trainer.train_dataloader.dataset),
)

(
self.sub_params,
self.sub_param_to_log_posterior,
) = uqlib.extract_requires_grad_and_func(
dict(self.model.named_parameters()), param_to_log_posterior
)
self.opt = AdamW(self.sub_params.values(), lr=self.lr, maximize=True)

def configure_optimizers(self):
pass

def training_step(self, batch, batch_idx):
self.opt.zero_grad()

log_post, out = self.sub_param_to_log_posterior(self.sub_params, batch)
log_post.backward()

self.log("log_post", log_post.item())
self.opt.step()

return log_post
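
For orientation, a minimal training sketch under stated assumptions: a Llama-style checkpoint (the module indexes model.model.layers, so GPT-2-style architectures will not work as-is), a small tokenized dataloader, and illustrative hyperparameters. The checkpoint name, dataset, and values below are placeholders rather than settings from this commit.

import lightning as L
from datasets import load_dataset
from ml_collections.config_dict import FrozenConfigDict
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from experiments.laplace_lora import TransformerModule

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder Llama-style checkpoint

config = FrozenConfigDict(
    {
        "pretrained_model_name_or_path": MODEL_NAME,
        "prior_sd": 1.0,
        "lr": 1e-4,
        "lora_config": {
            "target_modules": "last_layer",  # any other value adapts all attention layers
            "r": 8,
            "alpha": 32,
            "dropout": 0.1,
        },
    }
)

# Tiny illustrative dataset: tokenize a slice of wikitext into fixed-length input_ids.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:64]")
dataset = dataset.map(
    lambda ex: tokenizer(ex["text"], padding="max_length", max_length=128, truncation=True),
    batched=True,
).remove_columns(["text"])
dataset.set_format("torch")
train_loader = DataLoader(dataset, batch_size=4)

module = TransformerModule(config)
trainer = L.Trainer(max_epochs=1, log_every_n_steps=1)
trainer.fit(module, train_dataloaders=train_loader)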