Skip to content

Commit

Permalink
reverts some indexing in IITBehaviorModelPair that caused circuits-be…
Browse files Browse the repository at this point in the history
…nch to break
  • Loading branch information
evanhanders committed Aug 23, 2024
1 parent 0d1b24f commit 313309e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions iit/model_pairs/iit_behavior_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:

# compute behavioral accuracy
base_x, base_y = base_input[0:2]
output = self.ll_model(base_x)[label_idx.as_index] #convert ll logits -> one-hot max label
output = self.ll_model(base_x)
if self.hl_model.is_categorical():
top1 = t.argmax(output, dim=-1)
if output.shape[-1] == base_y.shape[-1]:
if output.shape == base_y.shape:
# To handle the case when labels are one-hot
# TODO: is there a better way?
base_y = t.argmax(base_y, dim=-1).squeeze()
base_y = t.argmax(base_y, dim=-1)
accuracy = (top1 == base_y).float().mean()
else:
accuracy = ((output - base_y).abs() < atol).float().mean()
Expand Down

0 comments on commit 313309e

Please sign in to comment.