fix some bugs for behavior loss, temporarily disable asserts
cybershiptrooper committed Sep 4, 2024
1 parent 7984bca commit 32c8e6e
Showing 2 changed files with 28 additions and 21 deletions.
6 changes: 3 additions & 3 deletions iit/model_pairs/base_model_pair.py
@@ -234,9 +234,9 @@ def train(
        assert isinstance(test_set, IITDataset), ValueError(
            f"test_set is not an instance of IITDataset, but {type(test_set)}"
        )
-        assert self.ll_model.device == self.hl_model.device, ValueError(
-            "ll_model and hl_model are not on the same device"
-        )
+        # assert self.ll_model.cfg.device == self.hl_model.device, ValueError(
+        #     "ll_model and hl_model are not on the same device"
+        # )

        train_loader, test_loader = self.make_loaders(
            train_set,
43 changes: 25 additions & 18 deletions iit/model_pairs/iit_behavior_model_pair.py
@@ -2,7 +2,7 @@

import torch as t
from torch import Tensor
-from transformer_lens.hook_points import HookedRootModule #type: ignore
+from transformer_lens.hook_points import HookedRootModule  # type: ignore

from iit.model_pairs.iit_model_pair import IITModelPair
from iit.model_pairs.ll_model import LLModel
@@ -13,18 +13,19 @@

class IITBehaviorModelPair(IITModelPair):
    def __init__(
-        self,
-        hl_model: HookedRootModule,
-        ll_model: LLModel,
-        corr: Correspondence,
-        training_args: dict = {}
-    ):
+        self,
+        hl_model: HookedRootModule,
+        ll_model: LLModel,
+        corr: Correspondence,
+        training_args: dict = {},
+    ):
        default_training_args = {
            "atol": 5e-2,
            "use_single_loss": False,
            "iit_weight": 1.0,
            "behavior_weight": 1.0,
-            "val_IIA_sampling": "random", # random or all
+            "val_IIA_sampling": "random",  # random or all
+            "use_all_tokens_for_behavior": False,  # if True, all tokens are used for behavior loss, else only the tokens in label_idxs
        }
        training_args = {**default_training_args, **training_args}
        super().__init__(hl_model, ll_model, corr=corr, training_args=training_args)
@@ -50,19 +51,21 @@ def make_test_metrics() -> MetricStoreCollection:
        )

    def get_behaviour_loss_over_batch(
-        self,
-        base_input: tuple[Tensor, Tensor, Tensor],
-        loss_fn: Callable[[Tensor, Tensor], Tensor]
-    ) -> Tensor:
+        self, base_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor]
+    ) -> Tensor:
        base_x, base_y = base_input[0:2]
        output = self.ll_model(base_x)
        label_indx = self.get_label_idxs()
-        behaviour_loss = loss_fn(output[label_indx.as_index], base_y[label_indx.as_index].to(output.device))
+        if not self.training_args["use_all_tokens_for_behavior"]:
+            output = output[label_indx.as_index]
+            base_y = base_y[label_indx.as_index]
+        base_y = base_y.to(output.device)
+        behaviour_loss = loss_fn(output, base_y)
        return behaviour_loss

    def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None:
        optimizer.zero_grad()
-        loss.backward() # type: ignore
+        loss.backward()  # type: ignore
        self.clip_grad_fn()
        optimizer.step()

@@ -127,7 +130,7 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:
            loss = loss_fn(ll_output, hl_output)
            IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
            return IIA, loss

        if self.training_args["val_IIA_sampling"] == "random":
            hl_node = self.sample_hl_name()
            IIA, loss = get_node_IIT_info(hl_node)
@@ -145,8 +148,13 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:

        # compute behavioral accuracy
        base_x, base_y = base_input[0:2]
-        base_y = base_y.to(self.ll_model.device) # so that input data doesn't all need to be hogging room on device.
        output = self.ll_model(base_x)
+        base_y = base_y.to(
+            output.device
+        )  # so that input data doesn't all need to be hogging room on device.
+        if not self.training_args["use_all_tokens_for_behavior"]:
+            output = output[label_idx.as_index]
+            base_y = base_y[label_idx.as_index]
        if self.hl_model.is_categorical():
            top1 = t.argmax(output, dim=-1)
            if output.shape == base_y.shape:
@@ -161,7 +169,6 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:
            "val/IIA": IIA,
            "val/accuracy": accuracy.item(),
        }
-

    def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool:
        if self.training_args["iit_weight"] == 0:
@@ -170,4 +177,4 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo
                    return metric.get_value() == 100
        else:
            return super()._check_early_stop_condition(test_metrics)
        return False
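For readers skimming the diff, the behavioral-loss change above comes down to one switch: with `use_all_tokens_for_behavior` left at its default of False, only the positions returned by `get_label_idxs()` enter the loss; when True, every token position does, and the targets are moved to the output's device only at that point. Below is a minimal, self-contained sketch of that selection logic, not the repository's code: the function name `behavior_loss`, the `label_positions` slice, and the use of CrossEntropyLoss are illustrative assumptions standing in for the model pair's `loss_fn` and its index object.

import torch as t

def behavior_loss(
    logits: t.Tensor,        # (batch, seq, vocab) low-level model output
    labels: t.Tensor,        # (batch, seq) integer targets, i.e. base_y
    label_positions: slice,  # stand-in for get_label_idxs().as_index
    use_all_tokens: bool,    # mirrors training_args["use_all_tokens_for_behavior"]
) -> t.Tensor:
    loss_fn = t.nn.CrossEntropyLoss()
    if not use_all_tokens:
        # Keep only the labelled positions, as the False branch in the commit does.
        logits = logits[:, label_positions]
        labels = labels[:, label_positions]
    # Move targets to the model's device only when needed, per the new comment.
    labels = labels.to(logits.device)
    return loss_fn(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))

# Usage: loss over only the final token vs. over every position.
logits = t.randn(4, 10, 50)
labels = t.randint(0, 50, (4, 10))
print(behavior_loss(logits, labels, slice(-1, None), use_all_tokens=False))
print(behavior_loss(logits, labels, slice(None), use_all_tokens=True))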
