diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py
index 1fb5ee2..61da1c1 100644
--- a/iit/model_pairs/base_model_pair.py
+++ b/iit/model_pairs/base_model_pair.py
@@ -234,9 +234,9 @@ def train(
         assert isinstance(test_set, IITDataset), ValueError(
             f"test_set is not an instance of IITDataset, but {type(test_set)}"
         )
-        assert self.ll_model.device == self.hl_model.device, ValueError(
-            "ll_model and hl_model are not on the same device"
-        )
+        # assert self.ll_model.cfg.device == self.hl_model.device, ValueError(
+        #     "ll_model and hl_model are not on the same device"
+        # )
 
         train_loader, test_loader = self.make_loaders(
             train_set,
diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py
index c871466..79ff2bc 100644
--- a/iit/model_pairs/iit_behavior_model_pair.py
+++ b/iit/model_pairs/iit_behavior_model_pair.py
@@ -2,7 +2,7 @@
 import torch as t
 from torch import Tensor
 
-from transformer_lens.hook_points import HookedRootModule #type: ignore
+from transformer_lens.hook_points import HookedRootModule  # type: ignore
 
 from iit.model_pairs.iit_model_pair import IITModelPair
 from iit.model_pairs.ll_model import LLModel
@@ -13,18 +13,19 @@ class IITBehaviorModelPair(IITModelPair):
     def __init__(
-        self,
-        hl_model: HookedRootModule,
-        ll_model: LLModel,
-        corr: Correspondence,
-        training_args: dict = {}
-    ):
+        self,
+        hl_model: HookedRootModule,
+        ll_model: LLModel,
+        corr: Correspondence,
+        training_args: dict = {},
+    ):
         default_training_args = {
             "atol": 5e-2,
             "use_single_loss": False,
             "iit_weight": 1.0,
             "behavior_weight": 1.0,
-            "val_IIA_sampling": "random", # random or all
+            "val_IIA_sampling": "random",  # random or all
+            "use_all_tokens_for_behavior": False,  # if True, all tokens are used for behavior loss, else only the tokens in label_idxs
         }
         training_args = {**default_training_args, **training_args}
         super().__init__(hl_model, ll_model, corr=corr, training_args=training_args)
@@ -50,19 +51,21 @@ def make_test_metrics() -> MetricStoreCollection:
         )
 
     def get_behaviour_loss_over_batch(
-        self,
-        base_input: tuple[Tensor, Tensor, Tensor],
-        loss_fn: Callable[[Tensor, Tensor], Tensor]
-    ) -> Tensor:
+        self, base_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor]
+    ) -> Tensor:
         base_x, base_y = base_input[0:2]
         output = self.ll_model(base_x)
         label_indx = self.get_label_idxs()
-        behaviour_loss = loss_fn(output[label_indx.as_index], base_y[label_indx.as_index].to(output.device))
+        if not self.training_args["use_all_tokens_for_behavior"]:
+            output = output[label_indx.as_index]
+            base_y = base_y[label_indx.as_index]
+        base_y = base_y.to(output.device)
+        behaviour_loss = loss_fn(output, base_y)
         return behaviour_loss
 
     def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None:
         optimizer.zero_grad()
-        loss.backward() # type: ignore
+        loss.backward()  # type: ignore
         self.clip_grad_fn()
         optimizer.step()
 
@@ -127,7 +130,7 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:
             loss = loss_fn(ll_output, hl_output)
             IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
             return IIA, loss
-        
+
         if self.training_args["val_IIA_sampling"] == "random":
             hl_node = self.sample_hl_name()
             IIA, loss = get_node_IIT_info(hl_node)
@@ -145,8 +148,13 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:
 
         # compute behavioral accuracy
         base_x, base_y = base_input[0:2]
-        base_y = base_y.to(self.ll_model.device) # so that input data doesn't all need to be hogging room on device.
         output = self.ll_model(base_x)
+        base_y = base_y.to(
+            output.device
+        )  # so that input data doesn't all need to be hogging room on device.
+        if not self.training_args["use_all_tokens_for_behavior"]:
+            output = output[label_idx.as_index]
+            base_y = base_y[label_idx.as_index]
         if self.hl_model.is_categorical():
             top1 = t.argmax(output, dim=-1)
             if output.shape == base_y.shape:
@@ -161,7 +169,6 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:
             "val/IIA": IIA,
             "val/accuracy": accuracy.item(),
         }
-
 
     def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool:
         if self.training_args["iit_weight"] == 0:
@@ -170,4 +177,4 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo
                 return metric.get_value() == 100
         else:
             return super()._check_early_stop_condition(test_metrics)
-        return False
\ No newline at end of file
+        return False
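
Below is a minimal, hypothetical usage sketch (not part of the patch) of the new `use_all_tokens_for_behavior` training argument added in this diff; the `hl_model`, `ll_model`, and `corr` objects are placeholders assumed to be constructed elsewhere.

    from iit.model_pairs.iit_behavior_model_pair import IITBehaviorModelPair

    # Compute behaviour loss/accuracy over all tokens instead of only the label_idxs positions.
    model_pair = IITBehaviorModelPair(
        hl_model=hl_model,
        ll_model=ll_model,
        corr=corr,
        training_args={"use_all_tokens_for_behavior": True},
    )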