diff --git a/experiments/continual_lora/configs/lora_sgd.yaml b/experiments/continual_lora/configs/lora_sgd.yaml
index 9b9674bf..d6e2505f 100644
--- a/experiments/continual_lora/configs/lora_sgd.yaml
+++ b/experiments/continual_lora/configs/lora_sgd.yaml
@@ -23,9 +23,9 @@ model_config: &model_params
 dataset_path: "./experiments/continual_lora/data/pg19-5.json"
 train_batch_size: 2
 laplace_batch_size: 1
-drop_last: true
+drop_last: false
 train_proportion: 0.85
-shuffle: true
+shuffle: false
 num_workers: 11
 tokenizer_pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
 stride_length: 4096
diff --git a/experiments/continual_lora/run_continual_experiment.py b/experiments/continual_lora/run_continual_experiment.py
index f75cd5b9..2d1619e0 100644
--- a/experiments/continual_lora/run_continual_experiment.py
+++ b/experiments/continual_lora/run_continual_experiment.py
@@ -76,17 +76,6 @@ def log_metrics(dir, current_epoch, episode_ind, val_ind, metrics):
     config, tokenizer, batch_size=config.laplace_batch_size
 )
 
-
-# Total number of tokens to predict
-def get_total_predict_tokens(data):
-    input_ids = data["input_ids"]
-    return input_ids.shape[0] * (input_ids.shape[1] - config.model_config.ignore_first)
-
-
-train_total_predict_tokens = [
-    get_total_predict_tokens(dl.dataset.data) for dl in train_dataloaders
-]
-
 # Load LoRA language model
 model = load_lora(config.model_config)
 model_func = uqlib.model_to_function(model)
@@ -96,7 +85,7 @@ def get_total_predict_tokens(data):
 # Convert logits to log likelihood
 # scale_factor is required to rescale stochastic log likelihood so that it is
 # an unbiased estimate of the full log likelihood
-def logits_to_log_lik(logits, labels, scale_factor):
+def logits_to_log_lik(logits, labels):
     # pred_seq_len = seq_length - ignore_first
     logits_start = config.model_config.ignore_first
     logits = logits[
@@ -105,14 +94,15 @@ def logits_to_log_lik(logits, labels, scale_factor):
     labels = labels[:, logits_start:].contiguous()  # (batch_size, pred_seq_len)
     log_lik = (
         Categorical(logits=logits, validate_args=False).log_prob(labels).sum(1).mean()
-    )  # average gives single token expected log likelihood
-    return log_lik * scale_factor  # rescale to full log likelihood
+    )  # sum over sequence, average over batch to give unbiased estimate of single sequence log likelihood
+    # (within sequence log likelihoods are not independent)
+    return log_lik  # rescale to full log likelihood
 
 
 # Full parameter set log_likelihood
-def param_to_log_likelihood(params, batch, scale_factor):
+def param_to_log_likelihood(params, batch):
     output = model_func(params, labels=batch["input_ids"], **batch)
-    log_lik = logits_to_log_lik(output.logits, batch["input_ids"], scale_factor)
+    log_lik = logits_to_log_lik(output.logits, batch["input_ids"])
     return log_lik, output
 
 
@@ -134,10 +124,12 @@ def normal_log_prior(params, prior_mean, prior_sd) -> float:
 
 
 # Combine likelihod and prior (on LoRA params only)
-def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
-    log_lik, output = sub_param_to_log_likelihood(p, batch, scale_factor)
+# Gives unbiased estimate of log posterior / num_sequences
+# Where num_sequences is the number of sequences in the training data
+def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, num_sequences):
+    log_lik, output = sub_param_to_log_likelihood(p, batch)
     log_prior = normal_log_prior(p, prior_mean, prior_sd)
-    log_post = log_lik + log_prior
+    log_post = log_lik + log_prior / num_sequences
     return log_post, output
 
 
@@ -163,13 +155,7 @@ def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
         sub_param_to_log_posterior,
         prior_mean=current_prior_mean,
         prior_sd=current_prior_sd,
-        scale_factor=train_total_predict_tokens[episode_ind],
-    )
-
-    # For validation we'll just assess the loss (expected token negative log likelihood)
-    test_sub_param_to_mean_log_lik = partial(
-        sub_param_to_log_likelihood,
-        scale_factor=1,
+        num_sequences=len(train_dataloaders[episode_ind].dataset.data["input_ids"]),
     )
 
     # Train for MAP
@@ -203,7 +189,7 @@ def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
         losses = []
         for batch in test_dl:
             batch = optree.tree_map(lambda x: x.to(args.device), batch)
-            log_lik, output = test_sub_param_to_mean_log_lik(sub_params, batch)
+            log_lik, output = sub_param_to_log_likelihood(sub_params, batch)
             losses.append(-log_lik.item())
 
         val_loss = sum(losses) / len(losses)
@@ -226,17 +212,8 @@ def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
 
     if config.lambda_param > 0.0 and episode_ind < config.num_tasks - 1:
         print(f"Fitting Laplace on episode {episode_ind + 1} of {config.num_tasks}")
-        # For Laplace we need to rescale up by predicted sequence length,
-        # as uqlib already sums across batches
-        laplace_sub_param_to_log_likelihood = partial(
-            sub_param_to_log_likelihood,
-            scale_factor=config.stride_length - config.model_config.ignore_first,
-        )
 
-        # Get Laplace precision diag
-        laplace_transform = uqlib.laplace.diag_fisher.build(
-            laplace_sub_param_to_log_likelihood
-        )
+        laplace_transform = uqlib.laplace.diag_fisher.build(sub_param_to_log_likelihood)
 
         # Update Laplace state through data
         laplace_state = laplace_transform.init(sub_params)
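For reference, a minimal sketch of the per-sequence log likelihood convention used in the updated logits_to_log_lik; the tensor sizes below are illustrative assumptions, not values from the repository:

import torch
from torch.distributions import Categorical

# Toy shapes standing in for (batch_size, pred_seq_len, vocab_size) after the
# ignore_first slicing (sizes are made up for illustration).
batch_size, pred_seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, pred_seq_len, vocab_size)
labels = torch.randint(vocab_size, (batch_size, pred_seq_len))

# log_prob returns shape (batch_size, pred_seq_len); summing over the sequence
# dimension and then averaging over the batch gives a per-sequence (not per-token)
# log likelihood estimate, matching the new convention in the diff.
log_lik = Categorical(logits=logits, validate_args=False).log_prob(labels).sum(1).mean()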
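Similarly, a toy check of the num_sequences scaling introduced in sub_param_to_log_posterior: with N training sequences the full log posterior is the sum of per-sequence log likelihoods plus the log prior, so batch-mean log likelihood plus log_prior / N is an unbiased estimate of log posterior / N. The numbers here are made up purely to demonstrate the identity:

import torch

torch.manual_seed(0)
num_sequences = 10
log_liks = torch.randn(num_sequences)  # stand-in per-sequence log likelihoods
log_prior = torch.tensor(-2.5)         # stand-in log prior value

# Target quantity: log posterior / num_sequences.
target = (log_liks.sum() + log_prior) / num_sequences

# Average the minibatch estimator over all (ordered) batches of size 2:
# mean(log_liks[batch]) + log_prior / num_sequences recovers the target,
# i.e. the estimator is unbiased under uniform batch sampling.
estimates = torch.stack(
    [
        log_liks[[i, j]].mean() + log_prior / num_sequences
        for i in range(num_sequences)
        for j in range(num_sequences)
        if i != j
    ]
)
assert torch.isclose(estimates.mean(), target, atol=1e-5)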