Skip to content

Commit

Permalink
fix(evals): performs logits mask when computing ce score. Ignoring pa…
Browse files Browse the repository at this point in the history
…d tokens
  • Loading branch information
Hzfinfdu committed Jun 26, 2024
1 parent 4c11b5d commit 6f11837
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/lm_saes/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def recons_loss_batched(
n_batches: int = 100,
):
losses = []
if (not cfg.use_ddp or cfg.rank == 0):
if not cfg.use_ddp or cfg.rank == 0:
pbar = tqdm(total=n_batches, desc="Evaluation", smoothing=0.01)
for _ in range(n_batches):
batch_tokens = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size)
Expand All @@ -118,10 +118,10 @@ def recons_loss_batched(
zero_abl_loss.mean().item(),
)
)
if (not cfg.use_ddp or cfg.rank == 0):
if not cfg.use_ddp or cfg.rank == 0:
pbar.update(1)

if (not cfg.use_ddp or cfg.rank == 0):
if not cfg.use_ddp or cfg.rank == 0:
pbar.close()

losses = pd.DataFrame(
Expand All @@ -139,7 +139,8 @@ def get_recons_loss(
batch_tokens: torch.Tensor,
):
batch_tokens = batch_tokens.to(torch.int64)
loss = model.forward(batch_tokens, return_type="loss")

loss = model.forward(batch_tokens, return_type="loss", loss_per_token=True)

_, cache = model.run_with_cache_until(
batch_tokens,
Expand All @@ -157,12 +158,23 @@ def replacement_hook(activations: torch.Tensor, hook: Any):
batch_tokens,
return_type="loss",
fwd_hooks=[(cfg.sae.hook_point_out, replacement_hook)],
loss_per_token=True
)

zero_abl_loss: torch.Tensor = model.run_with_hooks(
batch_tokens, return_type="loss", fwd_hooks=[(cfg.sae.hook_point_out, zero_ablate_hook)]
batch_tokens, return_type="loss", fwd_hooks=[(cfg.sae.hook_point_out, zero_ablate_hook)], loss_per_token=True
)

logits_mask = torch.logical_and(batch_tokens.ne(model.tokenizer.eos_token_id), batch_tokens.ne(model.tokenizer.pad_token_id))
logits_mask = torch.logical_and(logits_mask, batch_tokens.ne(model.tokenizer.bos_token_id))
logits_mask = logits_mask[:, 1:]

def get_useful_token_loss(per_token_loss):
per_token_loss = per_token_loss.where(logits_mask, 0)
return per_token_loss.sum() / per_token_loss.ne(0).sum()

loss, recons_loss, zero_abl_loss = get_useful_token_loss(loss), get_useful_token_loss(recons_loss), get_useful_token_loss(zero_abl_loss)

score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)

return score, loss, recons_loss, zero_abl_loss
Expand Down

0 comments on commit 6f11837

Please sign in to comment.