diff --git a/.gitignore b/.gitignore index 9dd3e441..67e02322 100644 --- a/.gitignore +++ b/.gitignore @@ -102,4 +102,7 @@ ENV/ /site # mypy -.mypy_cache/ \ No newline at end of file +.mypy_cache/ + +# Experiment runs +experiments/runs/ \ No newline at end of file diff --git a/experiments/eval_lora_transformer.py b/experiments/eval_lora_transformer.py new file mode 100644 index 00000000..7ceec0b7 --- /dev/null +++ b/experiments/eval_lora_transformer.py @@ -0,0 +1,164 @@ +import torch +import wandb +from tqdm import tqdm +import pickle +from omegaconf import OmegaConf +import os +from ml_collections.config_dict import ConfigDict +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from experiments.laplace_lora import TransformerModule +from experiments.utils.utils import load_config + + +def evaluate(model, tuned, dataset): + results = [] + avg_nlls_base = () + avg_ppls_base = () + avg_nlls_tuned = () + avg_ppls_tuned = () + for idx, sample in tqdm(enumerate(dataset)): + input_ids = sample["input_ids"].unsqueeze(0) + max_input_length = model.config.max_position_embeddings + + sample_nlls_base = [] + sample_nlls_tuned = [] + prev_end_loc = 0 + seq_len = input_ids.size(1) + for begin_loc in range(0, seq_len, 512): + end_loc = min(begin_loc + max_input_length, seq_len) + subseq = input_ids[:, begin_loc:end_loc] + targets = subseq.clone() + trg_len = end_loc - prev_end_loc + targets[:, :-trg_len] = -100 + + with torch.no_grad(): + output_base = model( + input_ids=subseq.to(model.device), + labels=targets, + ) + sample_nlls_base.append(output_base.loss) + + output_tuned = tuned.model( + input_ids=subseq.to(tuned.model.device), + labels=targets, + ) + sample_nlls_tuned.append(output_tuned.loss) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + + sample_nlls_base = torch.tensor(sample_nlls_base) + sample_ppls_base = torch.exp(sample_nlls_base) + + sample_avg_nll_base = torch.mean(sample_nlls_base) + sample_avg_ppl_base = torch.mean(sample_ppls_base) + wandb.log({"sample_avg_nll_base": sample_avg_nll_base}) + wandb.log({"sample_avg_ppl_base": sample_avg_ppl_base}) + + sample_nlls_tuned = torch.tensor(sample_nlls_tuned) + sample_ppls_tuned = torch.exp(sample_nlls_tuned) + + sample_avg_nll_tuned = torch.mean(sample_nlls_tuned) + sample_avg_ppl_tuned = torch.mean(sample_ppls_tuned) + wandb.log({"sample_avg_nll_tuned": sample_avg_nll_tuned}) + wandb.log({"sample_avg_ppl_tuned": sample_avg_ppl_tuned}) + + results += [ + { + "idx": idx, + "input_ids": sample["input_ids"], + "nlls_base": sample_nlls_base, + "nlls_tuned": sample_nlls_tuned, + "ppls_base": sample_ppls_base, + "ppls_tuned": sample_ppls_tuned, + "avg_nll_base": sample_avg_nll_base, + "avg_ppl_base": sample_avg_ppl_base, + "avg_nll_tuned": sample_avg_nll_tuned, + "avg_ppl_tuned": sample_avg_ppl_tuned, + } + ] + + avg_nlls_base += (sample_avg_nll_base,) + avg_ppls_base += (sample_avg_ppl_base,) + + avg_nlls_tuned += (sample_avg_nll_tuned,) + avg_ppls_tuned += (sample_avg_ppl_tuned,) + + avg_nll_base = torch.mean(torch.tensor(avg_nlls_base)) + avg_ppl_base = torch.mean(torch.tensor(avg_ppls_base)) + + avg_nll_tuned = torch.mean(torch.tensor(avg_nlls_tuned)) + avg_ppl_tuned = torch.mean(torch.tensor(avg_ppls_tuned)) + + wandb.log({"Avg NLL, Base Model": avg_nll_base}) + wandb.log({"Avg PPL, Base Model": avg_ppl_base}) + + wandb.log({"Avg NLL, Tuned Model": avg_nll_tuned}) + wandb.log({"Avg PPL, Tuned Model": avg_ppl_tuned}) + + return results + + +DATETIME = "" +EXPERIMENT_LOG_DIR = 
f"./experiments/runs/laplace_lora/{DATETIME}_laplace_lora" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LORA_WEIGHTS = EXPERIMENT_LOG_DIR + "/checkpoints/last.ckpt" + CONFIG = EXPERIMENT_LOG_DIR + "/laplace_lora.yaml" + config = ConfigDict(load_config(CONFIG)) + + model_tuned = TransformerModule.load_from_checkpoint( + LORA_WEIGHTS, config=config["model_config"] + ).to(device) + model = AutoModelForCausalLM.from_pretrained( + config.model_config.pretrained_model_name_or_path + ).to(device) + print("Weights loaded successfully!") + + wandb.init( + project=config["experiment_name"], + dir=config.get("logs_dir", "logs"), + ) + config.wandb_id_eval = wandb.run.id + config.wandb_name_eval = wandb.run.name + + OmegaConf.save( + config=config.to_dict(), + f=EXPERIMENT_LOG_DIR + f"/{config['experiment_name']}.yaml", + ) + + dataset = load_dataset(config.dataset_name) + tokenizer = AutoTokenizer.from_pretrained( + config.tokenizer_pretrained_model_name_or_path + ) + tokenizer.pad_token = tokenizer.eos_token + + def tokenize_function(examples): + return tokenizer( + examples["text"], + padding="max_length", + max_length=config.max_length, + truncation=True, + ) + + tokenized_datasets = dataset.map(tokenize_function, batched=True) + tokenized_datasets = tokenized_datasets.remove_columns([config.inputs_key]) + tokenized_datasets.set_format("torch") + + train_dataset = tokenized_datasets["train"] + eval_dataset = tokenized_datasets["test"] + + if config.small: + train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100)) + eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100)) + + results = evaluate(model, model_tuned, eval_dataset) + + result_file = os.path.join(EXPERIMENT_LOG_DIR, "results-eval.pkl") + with open(result_file, "wb") as f: + pickle.dump(results, f) diff --git a/experiments/laplace_lora/__init__.py b/experiments/laplace_lora/__init__.py new file mode 100644 index 00000000..13ea9aab --- /dev/null +++ b/experiments/laplace_lora/__init__.py @@ -0,0 +1 @@ +from experiments.laplace_lora.lora_transformer import TransformerModule diff --git a/experiments/laplace_lora/lora_transformer.py b/experiments/laplace_lora/lora_transformer.py new file mode 100644 index 00000000..db067d47 --- /dev/null +++ b/experiments/laplace_lora/lora_transformer.py @@ -0,0 +1,109 @@ +from itertools import groupby +from functools import partial +from optree import tree_map, tree_reduce +import lightning as L +import torch +from torch.optim import AdamW +from transformers import AutoModelForCausalLM +from peft import LoraConfig, TaskType, get_peft_model +from ml_collections.config_dict import FrozenConfigDict + +import uqlib +from uqlib import model_to_function + + +class TransformerModule(L.LightningModule): + def __init__(self, config: FrozenConfigDict): + super().__init__() + self.automatic_optimization = False + + self.pretrained_model_name_or_path = config.pretrained_model_name_or_path + self.prior_sd = config.prior_sd + self.lr = config.lr + + self.target_modules = config.lora_config.target_modules + self.r = config.lora_config.r + self.alpha = config.lora_config.alpha + self.dropout = config.lora_config.dropout + + model = AutoModelForCausalLM.from_pretrained(self.pretrained_model_name_or_path) + # only adapt W_q, W_v, W_o + # regex may not work for all models + + WEIGHTS_TO_LORA = ["q_proj", "v_proj", "o_proj"] + + modules = 
list(model.model.layers.named_parameters())
+        # Get (layer index, module name) pairs for the layers to adapt
+        module_names_with_layer = [
+            (name.split(".")[0], f"layer.{name.removesuffix('.weight')}")
+            for name, param in modules
+            if any(
+                sub in name
+                for sub in [
+                    "self_attn.{sub}".format(sub=sub) for sub in WEIGHTS_TO_LORA
+                ]
+            )
+        ]
+
+        # Subset of layers to adapt
+        if self.target_modules == "last_layer":
+            modules = [
+                [layer for name, layer in list(group)]
+                for _, group in groupby(module_names_with_layer, key=lambda x: x[0])
+            ][-1]
+        else:
+            modules = [name for layer, name in module_names_with_layer]
+
+        peft_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=modules,
+            r=self.r,
+            lora_alpha=self.alpha,
+            lora_dropout=self.dropout,
+        )
+
+        self.model = get_peft_model(model, peft_config)
+        self.model.print_trainable_parameters()
+        self.model_func = model_to_function(self.model)
+
+    @staticmethod
+    def univariate_normal_log_prob(x, mean, sd):
+        return -0.5 * ((x - mean) / sd) ** 2  # unnormalised: constant term dropped
+
+    def normal_log_prior(self, p) -> torch.Tensor:
+        per_group_vals = tree_map(
+            lambda p: self.univariate_normal_log_prob(p, 0, self.prior_sd).sum(), p
+        )
+        return tree_reduce(torch.add, per_group_vals)
+
+    def param_to_log_posterior(self, p, batch, num_data):
+        output = self.model_func(p, labels=batch["input_ids"], **batch)
+        return (-output.loss) + self.normal_log_prior(p) / num_data, output
+
+    def on_train_start(self) -> None:
+        param_to_log_posterior = partial(
+            self.param_to_log_posterior,
+            num_data=len(self.trainer.train_dataloader.dataset),
+        )
+
+        (
+            self.sub_params,
+            self.sub_param_to_log_posterior,
+        ) = uqlib.extract_requires_grad_and_func(
+            dict(self.model.named_parameters()), param_to_log_posterior
+        )
+        self.opt = AdamW(self.sub_params.values(), lr=self.lr, maximize=True)
+
+    def configure_optimizers(self):
+        pass
+
+    def training_step(self, batch, batch_idx):
+        self.opt.zero_grad()
+
+        log_post, out = self.sub_param_to_log_posterior(self.sub_params, batch)
+        log_post.backward()
+
+        self.log("log_post", log_post.item())
+        self.opt.step()
+
+        return log_post
diff --git a/experiments/run_laplace_lora.py b/experiments/run_laplace_lora.py
new file mode 100644
index 00000000..94014ff2
--- /dev/null
+++ b/experiments/run_laplace_lora.py
@@ -0,0 +1,131 @@
+import argparse
+import os
+import glob
+import datetime
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.loggers import WandbLogger
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from ml_collections import ConfigDict
+
+from experiments.utils import parse_devices, load_config, save_config, setup_log_dir
+from experiments.laplace_lora import TransformerModule
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--name", default=None, type=str)
+parser.add_argument("--resume", default=None, type=str)
+parser.add_argument("--base", default=None, type=str)
+parser.add_argument("--devices", default=parse_devices, type=str)
+parser.add_argument("--epochs", default=100, type=int)
+parser.add_argument("--log_frequency", default=10, type=int)
+parser.add_argument("--seed", default=42, type=int)
+
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    device_type = "cpu" if callable(args.devices) else "gpu"
+    if args.resume is None:
+        assert (
+            args.base is not None
+        ), "No config specified; provide either --base or --resume"
+        config = load_config(args.base)
+    else:
+        assert os.path.exists(
+            args.resume
+        ), 
"Provided path to resume training does not exist" + config_paths = glob.glob(os.path.join(args.resume, "*.yaml")) + assert len(config_paths) == 1, "Too many possible configs to resume from" + config = load_config(config_paths[0]) + + timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + experiment_name = config.get("experiment_name", None) + + experiment_log_dir = setup_log_dir( + config.get("logs_dir", "logs"), + timestamp, + resume=args.resume, + experiment_name=experiment_name, + ) + + torch.set_float32_matmul_precision("medium") + torch.manual_seed(args.seed) + + trainer_kwargs = { + "max_epochs": args.epochs, + "accelerator": device_type, + "log_every_n_steps": args.log_frequency, + } + + model = TransformerModule(config.model_config) + + config = ConfigDict(config) # thaw + logger = WandbLogger( + log_model="all", + project=config.get("experiment_name", ""), + save_dir=config.get("logs_dir", "logs"), + ) + config["wandb_name"] = logger.experiment.name + config["wandb_id"] = logger.experiment.id + + config["epochs"] = args.epochs + config["log_frequency"] = args.log_frequency + config["seed"] = args.seed + + if args.resume is None: + save_config( + config.to_dict(), f"{experiment_log_dir}/{os.path.basename(args.base)}" + ) + + trainer = Trainer(**trainer_kwargs, logger=logger) + + tokenizer = AutoTokenizer.from_pretrained( + config.model_config.pretrained_model_name_or_path + ) + tokenizer.pad_token = tokenizer.eos_token + model = TransformerModule(config.model_config) + + dataset = load_dataset(config.dataset_name) + + def tokenize_function(examples): + return tokenizer( + examples["text"], + padding="max_length", + max_length=config.max_length, + truncation=True, + ) + + tokenized_datasets = dataset.map(tokenize_function, batched=True) + tokenized_datasets = tokenized_datasets.remove_columns([config.inputs_key]) + tokenized_datasets.set_format("torch") + + train_dataset = tokenized_datasets["train"] + eval_dataset = tokenized_datasets["test"] + if config.small: + train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100)) + eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100)) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) + + test_dataloader = torch.utils.data.DataLoader( + eval_dataset, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) + + try: + resume_ckpt = None + if args.resume is not None: + resume_ckpt = os.path.join(args.resume, "checkpoints", "last.ckpt") + trainer.fit(model, train_dataloader, ckpt_path=resume_ckpt) + finally: + if trainer.global_rank == 0: + final_ckpt = os.path.join(experiment_log_dir, "checkpoints", "last.ckpt") + trainer.save_checkpoint(final_ckpt) diff --git a/experiments/utils/__init__.py b/experiments/utils/__init__.py index 499b2f03..b51202d9 100644 --- a/experiments/utils/__init__.py +++ b/experiments/utils/__init__.py @@ -1 +1 @@ -from .utils import parse_devices +from .utils import parse_devices, load_config, save_config, setup_log_dir diff --git a/experiments/utils/configs/laplace_lora.yaml b/experiments/utils/configs/laplace_lora.yaml new file mode 100644 index 00000000..818df614 --- /dev/null +++ b/experiments/utils/configs/laplace_lora.yaml @@ -0,0 +1,29 @@ +# File dirs +base_dir: &base_path "./experiments/" +logs_dir: &logs_path "./experiments/runs/laplace_lora/" +data_dir: &data_path "./experiments/laplace_lora/data/" + +experiment_name: "laplace_lora" + +# 
Model
+model_config: &model_params
+  pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
+  prior_sd: 1.0
+  lr: 0.00001
+
+  # LoRA
+  lora_config: &lora_params
+    target_modules: "last_layer"
+    r: 8
+    alpha: 32
+    dropout: 0.0
+
+# Dataset
+dataset_name: "timdettmers/openassistant-guanaco"
+batch_size: 4
+small: True
+num_workers: 11
+tokenizer_pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
+max_length: 4096
+truncation: True
+inputs_key: "text"
diff --git a/experiments/utils/utils.py b/experiments/utils/utils.py
index 99c09f42..09d87a7e 100644
--- a/experiments/utils/utils.py
+++ b/experiments/utils/utils.py
@@ -1,6 +1,10 @@
+import os
 from typing import List
 import torch
 from torch import nn
+from omegaconf import OmegaConf
+from lightning.pytorch.utilities import rank_zero_only
+from ml_collections.config_dict import FrozenConfigDict
 
 
 def parse_devices(devices):
@@ -30,3 +34,66 @@ def load_optimizer_param_to_model(model: nn.Module, groups: List[List[torch.Tens
 
     for model_param, optimizer_param in zip(list(model.parameters()), optimizer_params):
         model_param.data = optimizer_param
+
+
+REQUIRED_PARAMS = ["model_config", "experiment_name"]
+
+
+def load_config(file: str) -> FrozenConfigDict:
+    """
+    Load config file and check required keys are present
+    """
+    config = OmegaConf.load(file)
+    for param in REQUIRED_PARAMS:
+        assert param in config, f"Missing key {param} in config"
+
+    config = FrozenConfigDict(config)
+    return config
+
+
+@rank_zero_only
+def save_config(conf: OmegaConf, fp: str):
+    """
+    Save config file (on the rank zero process only)
+    """
+    OmegaConf.save(config=conf, f=fp)
+
+
+@rank_zero_only
+def create_log_dir(log_dir_name: str):
+    """
+    Create log directory (on the rank zero process only)
+    """
+    if not os.path.exists(log_dir_name):
+        os.mkdir(log_dir_name)
+
+
+def setup_log_dir(
+    log_dir_name: str,
+    timestamp: str,
+    resume: str = None,
+    experiment_name: str = None,
+) -> str:
+    """
+    Set up the run log directory, or return the resume directory if resuming
+    """
+    if resume:
+        return resume
+
+    # Create parent log dir (logs_dir may be nested)
+    if not os.path.exists(log_dir_name):
+        os.makedirs(log_dir_name)
+
+    # Create timestamp folder
+    log_dir_name = os.path.join(log_dir_name, timestamp)
+
+    # Add experiment name if specified
+    if experiment_name is not None:
+        log_dir_name += f"_{experiment_name}"
+
+    create_log_dir(log_dir_name)
+
+    # Create checkpoints folder
+    create_log_dir(f"{log_dir_name}/checkpoints")
+
+    return log_dir_name
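
Note (sketch, not part of the diff): normal_log_prior in lora_transformer.py sums an isotropic, unnormalised Gaussian log-density over the tree of trainable LoRA parameters via optree, and param_to_log_posterior adds it to the negative language-modelling loss scaled by 1/num_data, the usual per-datapoint log-posterior convention. A minimal standalone sketch of the prior term, with hypothetical toy tensors standing in for the real LoRA parameters:

import torch
from optree import tree_map, tree_reduce


def univariate_normal_log_prob(x, mean, sd):
    # Unnormalised Gaussian log-density (additive constant dropped), as in TransformerModule
    return -0.5 * ((x - mean) / sd) ** 2


def normal_log_prior(params, prior_sd):
    # Sum the elementwise log-densities over every tensor in the parameter tree
    per_tensor = tree_map(
        lambda p: univariate_normal_log_prob(p, 0.0, prior_sd).sum(), params
    )
    return tree_reduce(torch.add, per_tensor)


# Hypothetical stand-ins for the LoRA A/B matrices of a single adapted projection
toy_params = {"lora_A": torch.randn(8, 4096), "lora_B": torch.zeros(4096, 8)}
print(normal_log_prior(toy_params, prior_sd=1.0))

In training_step the resulting scalar log posterior is maximised directly with AdamW(maximize=True), i.e. gradient ascent towards a MAP estimate of the LoRA parameters.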