Commit

Update scaling
SamDuffield committed Feb 23, 2024
1 parent 9072b43 commit a3155cb
Showing 2 changed files with 16 additions and 39 deletions.
4 changes: 2 additions & 2 deletions experiments/continual_lora/configs/lora_sgd.yaml
@@ -23,9 +23,9 @@ model_config: &model_params
 dataset_path: "./experiments/continual_lora/data/pg19-5.json"
 train_batch_size: 2
 laplace_batch_size: 1
-drop_last: true
+drop_last: false
 train_proportion: 0.85
-shuffle: true
+shuffle: false
 num_workers: 11
 tokenizer_pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
 stride_length: 4096
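
For orientation, a minimal sketch of how these flags would typically be consumed, assuming a standard torch.utils.data.DataLoader (the experiment's actual loader construction is not part of this diff); with shuffle and drop_last now both false, every sequence is seen exactly once per epoch, in order, and the final partial batch is kept.

# Hypothetical illustration only; names and wiring are assumed, not taken from the repo.
from torch.utils.data import DataLoader

def build_train_dataloader(dataset, cfg):
    # cfg mirrors the keys in lora_sgd.yaml
    return DataLoader(
        dataset,
        batch_size=cfg["train_batch_size"],  # 2
        shuffle=cfg["shuffle"],              # now false: deterministic sequence order
        drop_last=cfg["drop_last"],          # now false: keep the last, smaller batch
        num_workers=cfg["num_workers"],      # 11
    )
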
51 changes: 14 additions & 37 deletions experiments/continual_lora/run_continual_experiment.py
@@ -76,17 +76,6 @@ def log_metrics(dir, current_epoch, episode_ind, val_ind, metrics):
     config, tokenizer, batch_size=config.laplace_batch_size
 )
 
-
-# Total number of tokens to predict
-def get_total_predict_tokens(data):
-    input_ids = data["input_ids"]
-    return input_ids.shape[0] * (input_ids.shape[1] - config.model_config.ignore_first)
-
-
-train_total_predict_tokens = [
-    get_total_predict_tokens(dl.dataset.data) for dl in train_dataloaders
-]
-
 # Load LoRA language model
 model = load_lora(config.model_config)
 model_func = uqlib.model_to_function(model)
@@ -96,7 +85,7 @@ def get_total_predict_tokens(data):
 # Convert logits to log likelihood
 # scale_factor is required to rescale stochastic log likelihood so that it is
 # an unbiased estimate of the full log likelihood
-def logits_to_log_lik(logits, labels, scale_factor):
+def logits_to_log_lik(logits, labels):
     # pred_seq_len = seq_length - ignore_first
     logits_start = config.model_config.ignore_first
     logits = logits[
@@ -105,14 +94,15 @@ def logits_to_log_lik(logits, labels, scale_factor):
     labels = labels[:, logits_start:].contiguous()  # (batch_size, pred_seq_len)
     log_lik = (
         Categorical(logits=logits, validate_args=False).log_prob(labels).sum(1).mean()
-    )  # average gives single token expected log likelihood
-    return log_lik * scale_factor  # rescale to full log likelihood
+    )  # sum over sequence, average over batch: an unbiased estimate of a single sequence's log likelihood
+    # (within-sequence log likelihoods are not independent)
+    return log_lik


 # Full parameter set log_likelihood
-def param_to_log_likelihood(params, batch, scale_factor):
+def param_to_log_likelihood(params, batch):
     output = model_func(params, labels=batch["input_ids"], **batch)
-    log_lik = logits_to_log_lik(output.logits, batch["input_ids"], scale_factor)
+    log_lik = logits_to_log_lik(output.logits, batch["input_ids"])
     return log_lik, output


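To make the new estimator concrete, here is a small self-contained illustration (toy shapes and random data, not the experiment's code) of what the updated logits_to_log_lik computes: per-token log probabilities are summed over the sequence and averaged over the batch, so each batch yields an unbiased estimate of the log likelihood of a single sequence.

# Toy illustration of the per-sequence log likelihood estimate.
import torch
from torch.distributions import Categorical

batch_size, pred_seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, pred_seq_len, vocab_size)
labels = torch.randint(vocab_size, (batch_size, pred_seq_len))

per_token = Categorical(logits=logits).log_prob(labels)  # (batch_size, pred_seq_len)
per_sequence = per_token.sum(1)                          # (batch_size,)
log_lik_estimate = per_sequence.mean()                   # scalar estimate of E[log p(sequence)]
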
@@ -134,10 +124,12 @@ def normal_log_prior(params, prior_mean, prior_sd) -> float:


 # Combine likelihood and prior (on LoRA params only)
-def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
-    log_lik, output = sub_param_to_log_likelihood(p, batch, scale_factor)
+# Gives an unbiased estimate of log posterior / num_sequences,
+# where num_sequences is the number of sequences in the training data
+def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, num_sequences):
+    log_lik, output = sub_param_to_log_likelihood(p, batch)
     log_prior = normal_log_prior(p, prior_mean, prior_sd)
-    log_post = log_lik + log_prior
+    log_post = log_lik + log_prior / num_sequences
     return log_post, output


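A quick numerical check of the scaling introduced above (made-up numbers, purely illustrative): dividing the prior by num_sequences means the per-batch objective averages, over all sequences, to the full log posterior divided by num_sequences, so the previous token-count scale_factor is no longer needed.

# Sanity check of: log_post = log_lik + log_prior / num_sequences
import torch

sequence_log_liks = torch.tensor([-4.0, -2.5, -3.0, -5.5])  # one entry per training sequence
log_prior = torch.tensor(-1.2)
num_sequences = len(sequence_log_liks)

per_batch_objective = sequence_log_liks + log_prior / num_sequences  # batches of size 1
full_log_posterior = sequence_log_liks.sum() + log_prior

assert torch.isclose(per_batch_objective.mean() * num_sequences, full_log_posterior)
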
@@ -163,13 +155,7 @@ def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
         sub_param_to_log_posterior,
         prior_mean=current_prior_mean,
         prior_sd=current_prior_sd,
-        scale_factor=train_total_predict_tokens[episode_ind],
-    )
-
-    # For validation we'll just assess the loss (expected token negative log likelihood)
-    test_sub_param_to_mean_log_lik = partial(
-        sub_param_to_log_likelihood,
-        scale_factor=1,
+        num_sequences=len(train_dataloaders[episode_ind].dataset.data["input_ids"]),
     )
 
     # Train for MAP
@@ -203,7 +189,7 @@ def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
         losses = []
         for batch in test_dl:
             batch = optree.tree_map(lambda x: x.to(args.device), batch)
-            log_lik, output = test_sub_param_to_mean_log_lik(sub_params, batch)
+            log_lik, output = sub_param_to_log_likelihood(sub_params, batch)
             losses.append(-log_lik.item())
 
         val_loss = sum(losses) / len(losses)
@@ -226,17 +212,8 @@ def sub_param_to_log_posterior(p, batch, prior_mean, prior_sd, scale_factor):
     if config.lambda_param > 0.0 and episode_ind < config.num_tasks - 1:
         print(f"Fitting Laplace on episode {episode_ind + 1} of {config.num_tasks}")
 
-        # For Laplace we need to rescale up by predicted sequence length,
-        # as uqlib already sums across batches
-        laplace_sub_param_to_log_likelihood = partial(
-            sub_param_to_log_likelihood,
-            scale_factor=config.stride_length - config.model_config.ignore_first,
-        )
-
         # Get Laplace precision diag
-        laplace_transform = uqlib.laplace.diag_fisher.build(
-            laplace_sub_param_to_log_likelihood
-        )
+        laplace_transform = uqlib.laplace.diag_fisher.build(sub_param_to_log_likelihood)
 
         # Update Laplace state through data
         laplace_state = laplace_transform.init(sub_params)
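
For completeness, a hedged sketch of how the rebuilt transform would then be driven over the Laplace data; only the init call is visible in this diff, so the dataloader name and the update signature below are assumptions, not the repository's exact code.

# Assumed continuation (not shown in the diff): accumulate the diagonal Fisher
# over the Laplace dataloader using the unscaled per-sequence log likelihood.
laplace_transform = uqlib.laplace.diag_fisher.build(sub_param_to_log_likelihood)
laplace_state = laplace_transform.init(sub_params)
for batch in laplace_dataloaders[episode_ind]:  # hypothetical name; batch_size = laplace_batch_size
    batch = optree.tree_map(lambda x: x.to(args.device), batch)
    laplace_state = laplace_transform.update(laplace_state, batch)  # assumed update(state, batch) API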
