Merge pull request #11 from normal-computing/laplace-lora
Laplace lora
Showing 8 changed files with 506 additions and 2 deletions.
@@ -102,4 +102,7 @@ ENV/
 /site

 # mypy
 .mypy_cache/
+
+# Experiment runs
+experiments/runs/
@@ -0,0 +1,164 @@
import torch
import wandb
from tqdm import tqdm
import pickle
from omegaconf import OmegaConf
import os
from ml_collections.config_dict import ConfigDict
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from experiments.laplace_lora import TransformerModule
from experiments.utils.utils import load_config


def evaluate(model, tuned, dataset):
    """Compare sliding-window NLL and perplexity of the base and LoRA-tuned models."""
    results = []
    avg_nlls_base = ()
    avg_ppls_base = ()
    avg_nlls_tuned = ()
    avg_ppls_tuned = ()
    for idx, sample in tqdm(enumerate(dataset)):
        input_ids = sample["input_ids"].unsqueeze(0)
        max_input_length = model.config.max_position_embeddings

        sample_nlls_base = []
        sample_nlls_tuned = []
        prev_end_loc = 0
        seq_len = input_ids.size(1)
        # Sliding-window evaluation with stride 512: tokens already scored in an
        # earlier window are masked with -100 so the loss ignores them.
        for begin_loc in range(0, seq_len, 512):
            end_loc = min(begin_loc + max_input_length, seq_len)
            subseq = input_ids[:, begin_loc:end_loc]
            targets = subseq.clone()
            trg_len = end_loc - prev_end_loc
            targets[:, :-trg_len] = -100

            with torch.no_grad():
                output_base = model(
                    input_ids=subseq.to(model.device),
                    labels=targets.to(model.device),
                )
                sample_nlls_base.append(output_base.loss)

                output_tuned = tuned.model(
                    input_ids=subseq.to(tuned.model.device),
                    labels=targets.to(tuned.model.device),
                )
                sample_nlls_tuned.append(output_tuned.loss)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        sample_nlls_base = torch.tensor(sample_nlls_base)
        sample_ppls_base = torch.exp(sample_nlls_base)

        sample_avg_nll_base = torch.mean(sample_nlls_base)
        sample_avg_ppl_base = torch.mean(sample_ppls_base)
        wandb.log({"sample_avg_nll_base": sample_avg_nll_base})
        wandb.log({"sample_avg_ppl_base": sample_avg_ppl_base})

        sample_nlls_tuned = torch.tensor(sample_nlls_tuned)
        sample_ppls_tuned = torch.exp(sample_nlls_tuned)

        sample_avg_nll_tuned = torch.mean(sample_nlls_tuned)
        sample_avg_ppl_tuned = torch.mean(sample_ppls_tuned)
        wandb.log({"sample_avg_nll_tuned": sample_avg_nll_tuned})
        wandb.log({"sample_avg_ppl_tuned": sample_avg_ppl_tuned})

        results += [
            {
                "idx": idx,
                "input_ids": sample["input_ids"],
                "nlls_base": sample_nlls_base,
                "nlls_tuned": sample_nlls_tuned,
                "ppls_base": sample_ppls_base,
                "ppls_tuned": sample_ppls_tuned,
                "avg_nll_base": sample_avg_nll_base,
                "avg_ppl_base": sample_avg_ppl_base,
                "avg_nll_tuned": sample_avg_nll_tuned,
                "avg_ppl_tuned": sample_avg_ppl_tuned,
            }
        ]

        avg_nlls_base += (sample_avg_nll_base,)
        avg_ppls_base += (sample_avg_ppl_base,)

        avg_nlls_tuned += (sample_avg_nll_tuned,)
        avg_ppls_tuned += (sample_avg_ppl_tuned,)

    avg_nll_base = torch.mean(torch.tensor(avg_nlls_base))
    avg_ppl_base = torch.mean(torch.tensor(avg_ppls_base))

    avg_nll_tuned = torch.mean(torch.tensor(avg_nlls_tuned))
    avg_ppl_tuned = torch.mean(torch.tensor(avg_ppls_tuned))

    wandb.log({"Avg NLL, Base Model": avg_nll_base})
    wandb.log({"Avg PPL, Base Model": avg_ppl_base})

    wandb.log({"Avg NLL, Tuned Model": avg_nll_tuned})
    wandb.log({"Avg PPL, Tuned Model": avg_ppl_tuned})

    return results


DATETIME = ""
EXPERIMENT_LOG_DIR = f"./experiments/runs/laplace_lora/{DATETIME}_laplace_lora"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LORA_WEIGHTS = EXPERIMENT_LOG_DIR + "/checkpoints/last.ckpt"
    CONFIG = EXPERIMENT_LOG_DIR + "/laplace_lora.yaml"
    config = ConfigDict(load_config(CONFIG))

    model_tuned = TransformerModule.load_from_checkpoint(
        LORA_WEIGHTS, config=config["model_config"]
    ).to(device)
    model = AutoModelForCausalLM.from_pretrained(
        config.model_config.pretrained_model_name_or_path
    ).to(device)
    print("Weights loaded successfully!")

    wandb.init(
        project=config["experiment_name"],
        dir=config.get("logs_dir", "logs"),
    )
    config.wandb_id_eval = wandb.run.id
    config.wandb_name_eval = wandb.run.name

    OmegaConf.save(
        config=config.to_dict(),
        f=EXPERIMENT_LOG_DIR + f"/{config['experiment_name']}.yaml",
    )

    dataset = load_dataset(config.dataset_name)
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_pretrained_model_name_or_path
    )
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            max_length=config.max_length,
            truncation=True,
        )

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns([config.inputs_key])
    tokenized_datasets.set_format("torch")

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["test"]

    if config.small:
        train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100))
        eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100))

    results = evaluate(model, model_tuned, eval_dataset)

    result_file = os.path.join(EXPERIMENT_LOG_DIR, "results-eval.pkl")
    with open(result_file, "wb") as f:
        pickle.dump(results, f)
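For reference, a minimal standalone sketch (not part of this commit) of the sliding-window masking used in evaluate above. The sequence length, window, and stride values are illustrative placeholders, not taken from the experiment config:

import torch

input_ids = torch.randint(0, 50_000, (1, 2000))  # toy token sequence
window, stride = 1024, 512
prev_end_loc = 0
for begin_loc in range(0, input_ids.size(1), stride):
    end_loc = min(begin_loc + window, input_ids.size(1))
    targets = input_ids[:, begin_loc:end_loc].clone()
    trg_len = end_loc - prev_end_loc  # tokens not yet scored by an earlier window
    targets[:, :-trg_len] = -100      # -100 marks positions the loss will ignore
    prev_end_loc = end_loc
    if end_loc == input_ids.size(1):
        break

Each window is then scored by the model with these masked targets, so tokens are not double-counted across overlapping windows.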
@@ -0,0 +1 @@
from experiments.laplace_lora.lora_transformer import TransformerModule
@@ -0,0 +1,109 @@
from itertools import groupby
from functools import partial
from optree import tree_map, tree_reduce
import lightning as L
import torch
from torch.optim import AdamW
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
from ml_collections.config_dict import FrozenConfigDict

import uqlib
from uqlib import model_to_function


class TransformerModule(L.LightningModule):
    def __init__(self, config: FrozenConfigDict):
        super().__init__()
        self.automatic_optimization = False

        self.pretrained_model_name_or_path = config.pretrained_model_name_or_path
        self.prior_sd = config.prior_sd
        self.lr = config.lr

        self.target_modules = config.lora_config.target_modules
        self.r = config.lora_config.r
        self.alpha = config.lora_config.alpha
        self.dropout = config.lora_config.dropout

        model = AutoModelForCausalLM.from_pretrained(self.pretrained_model_name_or_path)
        # Only adapt W_q, W_v, W_o.
        # This name matching may not work for all models.
        WEIGHTS_TO_LORA = ["q_proj", "v_proj", "o_proj"]

        modules = list(model.model.layers.named_parameters())
        # Get (layer index, module name) for the layers to adapt
        module_names_with_layer = [
            (name.split(".")[0], f"layer.{name.removesuffix('.weight')}")
            for name, param in modules
            if any(f"self_attn.{sub}" in name for sub in WEIGHTS_TO_LORA)
        ]

        # Subset of layers to adapt
        if self.target_modules == "last_layer":
            modules = [
                [layer for name, layer in list(group)]
                for _, group in groupby(module_names_with_layer, key=lambda x: x[0])
            ][-1]
        else:
            modules = [name for layer, name in module_names_with_layer]

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=modules,
            r=self.r,
            lora_alpha=self.alpha,
            lora_dropout=self.dropout,
        )

        self.model = get_peft_model(model, peft_config)
        self.model.print_trainable_parameters()
        self.model_func = model_to_function(self.model)

    @staticmethod
    def univariate_normal_log_prob(x, mean, sd):
        return -0.5 * ((x - mean) / sd) ** 2

    def normal_log_prior(self, p) -> float:
        per_group_vals = tree_map(
            lambda p: self.univariate_normal_log_prob(p, 0, self.prior_sd).sum(), p
        )
        return tree_reduce(torch.add, per_group_vals)

    def param_to_log_posterior(self, p, batch, num_data) -> torch.Tensor:
        output = self.model_func(p, labels=batch["input_ids"], **batch)
        # Per-datum log posterior: mean log-likelihood plus the prior scaled by 1 / num_data.
        return (-output.loss) + self.normal_log_prior(p) / num_data, output

    def on_train_start(self) -> None:
        param_to_log_posterior = partial(
            self.param_to_log_posterior,
            num_data=len(self.trainer.train_dataloader.dataset),
        )

        # Restrict the log posterior to the parameters that require grad
        # (the LoRA parameters) and optimize them with AdamW, maximizing.
        (
            self.sub_params,
            self.sub_param_to_log_posterior,
        ) = uqlib.extract_requires_grad_and_func(
            dict(self.model.named_parameters()), param_to_log_posterior
        )
        self.opt = AdamW(self.sub_params.values(), lr=self.lr, maximize=True)

    def configure_optimizers(self):
        # Manual optimization: the optimizer is built in on_train_start instead.
        pass

    def training_step(self, batch, batch_idx):
        self.opt.zero_grad()

        log_post, out = self.sub_param_to_log_posterior(self.sub_params, batch)
        log_post.backward()

        self.log("log_post", log_post.item())
        self.opt.step()

        return log_post
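A minimal sketch (not from this commit) of how TransformerModule might be driven with a Lightning Trainer. The model name, config values, and dummy dataset below are placeholders for illustration; the real values are read from the experiment YAML:

import lightning as L
import torch
from torch.utils.data import DataLoader
from ml_collections.config_dict import FrozenConfigDict

# Placeholder config; the field names mirror those accessed in __init__ above,
# and the model id assumes a Llama-style architecture with q_proj/v_proj/o_proj.
config = FrozenConfigDict(
    {
        "pretrained_model_name_or_path": "meta-llama/Llama-2-7b-hf",
        "prior_sd": 1.0,
        "lr": 1e-4,
        "lora_config": {"target_modules": "all", "r": 8, "alpha": 32, "dropout": 0.1},
    }
)

module = TransformerModule(config)

# Tiny dummy dataset of token ids, purely to show the call pattern.
train_batches = [{"input_ids": torch.randint(0, 32_000, (512,))} for _ in range(8)]

trainer = L.Trainer(max_epochs=1)
trainer.fit(module, DataLoader(train_batches, batch_size=4))

Since automatic_optimization is disabled, the AdamW optimizer built in on_train_start performs the maximized log-posterior updates inside training_step.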