diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml
new file mode 100644
index 0000000..7458094
--- /dev/null
+++ b/.github/workflows/mypy.yml
@@ -0,0 +1,24 @@
+name: Mypy Type Checking
+
+on: [push, pull_request]
+
+jobs:
+  mypy:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install mypy
+      - name: Run mypy
+        run: |
+          mypy --config-file mypy-${{ matrix.python-version }}.ini
diff --git a/eval_causality.py b/eval_causality.py
index 2ebcfec..dda3582 100644
--- a/eval_causality.py
+++ b/eval_causality.py
@@ -2,6 +2,7 @@
 from datetime import datetime
 
 import torch as t
+from torch import Tensor
 import wandb
 from tqdm import tqdm
 
@@ -10,6 +11,7 @@
 from iit.utils.config import DEVICE
 from iit.utils.plotter import plot_ablation_stats
 from iit.utils.wrapper import get_hook_points
+from iit.tasks.mnist_pvr.dataset import ImagePVRDataset
 
 
 def evaluate_model_on_ablations(
@@ -18,7 +20,7 @@ def evaluate_model_on_ablations(
     test_set: t.utils.data.Dataset,
     eval_args: dict,
     verbose: bool = False,
-):
+) -> dict:
     print("reached evaluate_model!")
     stats_per_layer = {}
     for hook_point in tqdm(get_hook_points(ll_model), desc="Hook points"):
@@ -26,7 +28,7 @@ def evaluate_model_on_ablations(
            task,
            config={
                "hook_point": hook_point,
-               "input_shape": test_set.get_input_shape(),
+               "input_shape": test_set.get_input_shape(),  # type: ignore
            },
        )
        model_pair = IITProbeSequentialPair(
@@ -40,23 +42,26 @@ def evaluate_model_on_ablations(
         # set up stats
         hookpoint_stats = {}
         for hl_node, _ in model_pair.corr.items():
-            hookpoint_stats[hl_node] = 0
+            hookpoint_stats[hl_node] = t.zeros(1)
         # find test accuracy
         with t.no_grad():
             for base_input_lists in tqdm(dataloader, desc=f"Ablations on {hook_point}"):
-                base_input = [x.to(DEVICE) for x in base_input_lists]
+                base_input: tuple[Tensor, Tensor, Tensor] = (x.to(DEVICE) for x in base_input_lists)  # type: ignore
                 for hl_node, ll_nodes in model_pair.corr.items():
-                    ablated_input = test_set.patch_batch_at_hl(
-                        list(base_input[0]),
-                        list(base_input_lists[-1]),
-                        hl_node,
-                        list(base_input[1]),
-                    )
-                    ablated_input = (
-                        t.stack(ablated_input[0]).to(DEVICE),  # input
-                        t.stack(ablated_input[1]).to(DEVICE),  # label
-                        t.stack(ablated_input[2]).to(DEVICE),
-                    )  # intermediate_data
+                    if isinstance(test_set, ImagePVRDataset):
+                        ablated_input_pre = test_set.patch_batch_at_hl(
+                            list(base_input),
+                            list(base_input_lists),
+                            hl_node,
+                        )
+                        ablated_input = (
+                            t.stack(ablated_input_pre[0]).to(DEVICE),  # input
+                            t.stack(ablated_input_pre[1]).to(DEVICE),  # label
+                            t.stack(ablated_input_pre[2]).to(DEVICE),
+                        )  # intermediate_data
+                    else:
+                        raise ValueError(f"patch_batch_at_hl not implemented for this dataset type: {type(test_set)}")
+
                     # unsqueeze if single element
                     if ablated_input[1].shape == ():
                         assert (
@@ -115,6 +120,8 @@ def evaluate_model_on_ablations(
     ll_model, hl_model, corr = get_alignment(
         task, config={"input_shape": test_set.get_input_shape()}
     )
+    assert ll_model is not None
+    assert hl_model is not None
     model_pair = IITProbeSequentialPair(
         ll_model=ll_model, hl_model=hl_model, corr=corr, training_args=training_args
     )
@@ -122,7 +129,7 @@ def evaluate_model_on_ablations(
        model_pair.train(
            train_set,
            test_set,
-           epochs=training_args["epochs"],
+
epochs=int(training_args["epochs"]), use_wandb=use_wandb, ) else: @@ -147,8 +154,8 @@ def evaluate_model_on_ablations( if use_wandb: wandb.init(project="iit") - wandb.run.name = f"{leaky_task}_ablation" - wandb.run.save() + wandb.run.name = f"{leaky_task}_ablation" # type: ignore + wandb.run.save() # type: ignore wandb.config.update(eval_args) leaky_stats_per_layer = evaluate_model_on_ablations( diff --git a/eval_information.py b/eval_information.py index c75f549..d39ad5b 100644 --- a/eval_information.py +++ b/eval_information.py @@ -7,6 +7,7 @@ import torch as t from tqdm import tqdm from iit.utils.plotter import plot_probe_stats +from iit.utils.iit_dataset import IITDataset import os import wandb from datetime import datetime @@ -21,14 +22,14 @@ def evaluate_model_on_probes( use_wandb: bool = False, verbose: bool = False, save_probes: bool = False, -): +) -> dict: print("reached evaluate_model!") probe_stats_per_layer = {} log_stats_per_layer = {} if use_wandb: wandb.init(project="iit") - wandb.run.name = f"{task}_probes" - wandb.run.save() + wandb.run.name = f"{task}_probes" # type: ignore + wandb.run.save() # type: ignore # add training args to wandb config wandb.config.update(probe_training_args) @@ -37,7 +38,7 @@ def evaluate_model_on_probes( task, config={ "hook_point": hook_point, - "input_shape": test_set.get_input_shape(), + "input_shape": test_set.get_input_shape(), # type: ignore }, ) model_pair = IITProbeSequentialPair( @@ -47,7 +48,7 @@ def evaluate_model_on_probes( training_args=probe_training_args, ) - input_shape = train_set.get_input_shape() + input_shape = train_set.get_input_shape() # type: ignore trainer_out = train_probes_on_model_pair( model_pair, input_shape, train_set, probe_training_args ) @@ -103,13 +104,15 @@ def evaluate_model_on_probes( ll_model, hl_model, corr = get_alignment( task, config={"input_shape": test_set.get_input_shape()} ) + assert ll_model is not None + assert hl_model is not None model_pair = IITProbeSequentialPair( ll_model=ll_model, hl_model=hl_model, corr=corr, training_args=training_args ) model_pair.train( train_set, test_set, - epochs=training_args["epochs"], + epochs=int(training_args["epochs"]), use_wandb=use_wandb, ) if use_wandb: diff --git a/eval_ioi.py b/eval_ioi.py index 6f24c52..b728cd2 100644 --- a/eval_ioi.py +++ b/eval_ioi.py @@ -1,4 +1,5 @@ import argparse +from iit.utils.argparsing import IOIArgParseNamespace from iit.utils.eval_scripts import eval_ioi import torch @@ -45,4 +46,5 @@ ) args = parser.parse_args() - eval_ioi(args) + namespace = IOIArgParseNamespace(**vars(args)) + eval_ioi(namespace) diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index ff44f3f..cf17c88 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -1,20 +1,21 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, final +from typing import Any, Callable, final, Type import numpy as np import torch as t -import transformer_lens as tl +import transformer_lens as tl # type: ignore from torch import Tensor -from tqdm import tqdm -from transformer_lens.hook_points import HookedRootModule, HookPoint +from torch.utils.data import DataLoader +from tqdm import tqdm # type: ignore +from transformer_lens.hook_points import HookedRootModule, HookPoint # type: ignore -import wandb +import wandb # type: ignore from iit.model_pairs.ll_model import LLModel from iit.utils.nodes import HLNode, LLNode from iit.utils.config import WANDB_ENTITY from iit.utils.correspondence import 
Correspondence from iit.utils.iit_dataset import IITDataset -from iit.utils.index import Ix +from iit.utils.index import Ix, TorchIndex from iit.utils.metric import MetricStoreCollection, MetricType @@ -50,20 +51,20 @@ def make_test_metrics() -> MetricStoreCollection: @abstractmethod def run_train_step( self, - base_input, - ablation_input, - loss_fn, - optimizer, - ) -> MetricStoreCollection: + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], + loss_fn: Callable[[Tensor, Tensor], Tensor], + optimizer: t.optim.Optimizer, + ) -> dict: pass @abstractmethod def run_eval_step( self, - base_input, - ablation_input, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor], - ) -> MetricStoreCollection: + ) -> dict: pass ########################################### @@ -71,10 +72,10 @@ def run_eval_step( ########################################### def do_intervention( self, - base_input: tuple[t.Tensor, t.Tensor, t.Tensor], - ablation_input: tuple[t.Tensor, t.Tensor, t.Tensor], + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], hl_node: HLNode, - verbose=False + verbose: bool = False ) -> tuple[Tensor, Tensor]: ablation_x, ablation_y = ablation_input[0:2] base_x, base_y = base_input[0:2] @@ -102,23 +103,20 @@ def do_intervention( return hl_output, ll_output @staticmethod - def get_label_idxs(): + def get_label_idxs() -> TorchIndex: ''' Returns the index of the label for which the IIT loss is computed. NOT to be used for computing the behavior loss. ''' return Ix[[None]] - def make_hl_model(self, hl_graph): - raise NotImplementedError - - def set_corr(self, corr): + def set_corr(self, corr: Correspondence) -> None: self.corr = corr def sample_hl_name(self) -> HLNode: - return self.rng.choice(list(self.corr.keys())) + return self.rng.choice(np.array(list(self.corr.keys()))) - def make_hl_ablation_hook(self, hl_node: HLNode): + def make_hl_ablation_hook(self, hl_node: HLNode) -> Callable[[Tensor, HookPoint], Tensor]: assert isinstance(hl_node, HLNode), ValueError( f"hl_node is not an instance of HLNode, but {type(hl_node)}" ) @@ -143,7 +141,9 @@ def hl_ablation_hook(hook_point_out: Tensor, hook: HookPoint) -> Tensor: return self.hl_ablation_hook def hl_ablation_hook( - self, hook_point_out: Tensor, hook: HookPoint + self, + hook_point_out: Tensor, + hook: HookPoint ) -> Tensor: # TODO: remove this out = self.hl_cache[hook.name] return out @@ -169,24 +169,28 @@ def ll_ablation_hook(hook_point_out: Tensor, hook: HookPoint) -> Tensor: def get_IIT_loss_over_batch( self, - base_input, - ablation_input, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], hl_node: HLNode, loss_fn: Callable[[Tensor, Tensor], Tensor], - ): + ) -> Tensor: hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) label_idx = self.get_label_idxs() # IIT loss is only computed on the tokens we care about loss = loss_fn(ll_output[label_idx.as_index], hl_output[label_idx.as_index]) return loss - def clip_grad_fn(self): + def clip_grad_fn(self) -> None: if self.training_args["clip_grad_norm"]: t.nn.utils.clip_grad_norm_( self.ll_model.parameters(), self.training_args["clip_grad_norm"] ) - def step_scheduler(self, lr_scheduler, test_metrics): + def step_scheduler( + self, + lr_scheduler: t.optim.lr_scheduler.LRScheduler, + test_metrics: MetricStoreCollection + ) -> None: if isinstance(lr_scheduler, 
t.optim.lr_scheduler.ReduceLROnPlateau): accuracy_metric = 0 for metric in self.training_args.get("scheduler_val_metric", ["val/accuracy"]): @@ -212,12 +216,14 @@ def step_scheduler(self, lr_scheduler, test_metrics): def train( self, - train_set, - test_set, - epochs=1000, - use_wandb=False, - wandb_name_suffix="", - ): + train_set: IITDataset, + test_set: IITDataset, + optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam, + epochs: int = 1000, + use_wandb: bool = False, + wandb_name_suffix: str = "", + optimizer_kwargs: dict = {}, + ) -> None: training_args = self.training_args print(f"{training_args=}") @@ -236,7 +242,8 @@ def train( early_stop = training_args["early_stop"] - optimizer = t.optim.Adam(self.ll_model.parameters(), lr=training_args["lr"]) + optimizer_kwargs['lr'] = training_args["lr"] + optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs) loss_fn = self.loss_fn scheduler_cls = training_args.get("lr_scheduler", None) if scheduler_cls == t.optim.lr_scheduler.ReduceLROnPlateau: @@ -252,7 +259,7 @@ def train( if use_wandb: wandb.config.update(training_args) wandb.config.update({"method": self.wandb_method}) - wandb.run.log_code() + wandb.run.log_code() # type: ignore for epoch in tqdm(range(epochs)): train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer) @@ -262,10 +269,10 @@ def train( self.test_metrics = test_metrics self.train_metrics = train_metrics self._print_and_log_metrics( - epoch, train_metrics.metrics + test_metrics.metrics, use_wandb + epoch, MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), use_wandb ) - if early_stop and self._check_early_stop_condition(test_metrics.metrics): + if early_stop and self._check_early_stop_condition(test_metrics): break if use_wandb: @@ -279,15 +286,20 @@ def train( def make_loaders( dataset: IITDataset, test_dataset: IITDataset, - batch_size, - num_workers, - ): + batch_size : int, + num_workers : int, + ) -> tuple[DataLoader, DataLoader]: loader = dataset.make_loader(batch_size, num_workers) test_loader = test_dataset.make_loader(batch_size, num_workers) return loader, test_loader @final - def _run_train_epoch(self, loader, loss_fn, optimizer) -> MetricStoreCollection: + def _run_train_epoch( + self, + loader: DataLoader, + loss_fn: Callable[[Tensor, Tensor], Tensor], + optimizer: t.optim.Optimizer + ) -> MetricStoreCollection: self.ll_model.train() train_metrics = self.make_train_metrics() for i, (base_input, ablation_input) in tqdm( @@ -299,7 +311,11 @@ def _run_train_epoch(self, loader, loss_fn, optimizer) -> MetricStoreCollection: return train_metrics @final - def _run_eval_epoch(self, loader, loss_fn) -> MetricStoreCollection: + def _run_eval_epoch( + self, + loader: DataLoader, + loss_fn: Callable[[Tensor, Tensor], Tensor] + ) -> MetricStoreCollection: self.ll_model.eval() test_metrics = self.make_test_metrics() with t.no_grad(): @@ -309,8 +325,7 @@ def _run_eval_epoch(self, loader, loss_fn) -> MetricStoreCollection: ) return test_metrics - @staticmethod - def _check_early_stop_condition(test_metrics): + def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool: """ Returns True if all types of accuracy metrics reach 100% """ @@ -318,7 +333,8 @@ def _check_early_stop_condition(test_metrics): for metric in test_metrics: if metric.type == MetricType.ACCURACY: got_accuracy_metric = True - if metric.get_value() < 100: + val = metric.get_value() + if isinstance(val, float) and val < 100: return False if not got_accuracy_metric: raise ValueError("No 
accuracy metric found in test_metrics!") @@ -326,7 +342,11 @@ def _check_early_stop_condition(test_metrics): @final @staticmethod - def _print_and_log_metrics(epoch, metrics, use_wandb=False): + def _print_and_log_metrics( + epoch: int, + metrics: MetricStoreCollection, + use_wandb: bool = False + ) -> None: print(f"\nEpoch {epoch}:", end=" ") if use_wandb: wandb.log({"epoch": epoch}) diff --git a/iit/model_pairs/freeze_model_pair.py b/iit/model_pairs/freeze_model_pair.py index e14ecc8..141f207 100644 --- a/iit/model_pairs/freeze_model_pair.py +++ b/iit/model_pairs/freeze_model_pair.py @@ -1,10 +1,21 @@ -from iit.model_pairs.iit_behavior_model_pair import IITBehaviorModelPair -import iit.utils.node_picker as node_picker + import torch as t +from transformer_lens.hook_points import HookedRootModule #type: ignore +from iit.model_pairs.iit_behavior_model_pair import IITBehaviorModelPair +from iit.model_pairs.ll_model import LLModel +from iit.utils.correspondence import Correspondence +import iit.utils.node_picker as node_picker class FreezedModelPair(IITBehaviorModelPair): - def __init__(self, hl_model, ll_model, corr, training_args={}): + + def __init__( + self, + hl_model: HookedRootModule, + ll_model: LLModel, + corr: Correspondence, + training_args: dict = {} + ): default_training_args = { "batch_size": 256, "lr": 0.001, @@ -18,7 +29,7 @@ def __init__(self, hl_model, ll_model, corr, training_args={}): self.params_not_in_circuit = node_picker.get_params_not_in_circuit(corr, ll_model) self.wandb_method = "freeze_unwanted" - def zero_grad_for_not_in_circuit(self): + def zero_grad_for_not_in_circuit(self) -> None: for ll_node in self.params_not_in_circuit: for name, param in self.ll_model.named_parameters(): if ll_node.name == name: @@ -28,9 +39,9 @@ def zero_grad_for_not_in_circuit(self): assert param.grad is not None # assert (param.grad.abs().sum() != 0) or (param_idx.as_index == index.Ix[[None]].as_index), f"got {param.grad.abs().sum()} and {param_idx.as_index} and {index.Ix[[None]].as_index}" - def step_on_loss(self, loss, optimizer): + def step_on_loss(self, loss: t.Tensor, optimizer: t.optim.Optimizer) -> None: optimizer.zero_grad() - loss.backward() + loss.backward() # type: ignore self.zero_grad_for_not_in_circuit() # else, no need as we do it via hooks self.clip_grad_fn() optimizer.step() diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index fc4a05a..4a74742 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -2,13 +2,22 @@ import torch as t from torch import Tensor +from transformer_lens.hook_points import HookedRootModule #type: ignore from iit.model_pairs.iit_model_pair import IITModelPair +from iit.model_pairs.ll_model import LLModel +from iit.utils.correspondence import Correspondence from iit.utils.metric import MetricStore, MetricStoreCollection, MetricType class IITBehaviorModelPair(IITModelPair): - def __init__(self, hl_model, ll_model, corr, training_args={}): + def __init__( + self, + hl_model: HookedRootModule, + ll_model: LLModel, + corr: Correspondence, + training_args: dict = {} + ): default_training_args = { "lr": 0.001, "atol": 5e-2, @@ -22,7 +31,7 @@ def __init__(self, hl_model, ll_model, corr, training_args={}): self.wandb_method = "iit_and_behavior" @staticmethod - def make_train_metrics(): + def make_train_metrics() -> MetricStoreCollection: return MetricStoreCollection( [ MetricStore("train/iit_loss", MetricType.LOSS), @@ -31,7 +40,7 @@ def 
make_train_metrics(): ) @staticmethod - def make_test_metrics(): + def make_test_metrics() -> MetricStoreCollection: return MetricStoreCollection( [ MetricStore("val/iit_loss", MetricType.LOSS), @@ -40,30 +49,31 @@ def make_test_metrics(): ] ) - def get_behaviour_loss_over_batch(self, base_input, loss_fn): + def get_behaviour_loss_over_batch( + self, + base_input: tuple[Tensor, Tensor, Tensor], + loss_fn: Callable[[Tensor, Tensor], Tensor] + ) -> Tensor: base_x, base_y = base_input[0:2] output = self.ll_model(base_x) behavior_loss = loss_fn(output, base_y) return behavior_loss - def step_on_loss(self, loss, optimizer): + def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None: optimizer.zero_grad() - loss.backward() + loss.backward() # type: ignore self.clip_grad_fn() optimizer.step() def run_train_step( self, - base_input, - ablation_input, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor], optimizer: t.optim.Optimizer, - ): + ) -> dict: use_single_loss = self.training_args["use_single_loss"] - iit_loss = 0 - behavior_loss = 0 - hl_node = self.sample_hl_name() # sample a high-level variable to ablate iit_loss = ( self.get_IIT_loss_over_batch(base_input, ablation_input, hl_node, loss_fn) @@ -90,10 +100,10 @@ def run_train_step( def run_eval_step( self, - base_input, - ablation_input, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor], - ): + ) -> dict: atol = self.training_args["atol"] # compute IIT loss and accuracy @@ -134,10 +144,11 @@ def run_eval_step( } - def _check_early_stop_condition(self, test_metrics: list[MetricStore]): + def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool: if self.training_args["iit_weight"] == 0: for metric in test_metrics: if metric.get_name() == "val/accuracy": return metric.get_value() == 100 else: - return super()._check_early_stop_condition(test_metrics) \ No newline at end of file + return super()._check_early_stop_condition(test_metrics) + return False \ No newline at end of file diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index 6f01062..8cd720d 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -3,10 +3,9 @@ import numpy as np import torch as t from torch import Tensor -from transformer_lens.hook_points import HookedRootModule +from transformer_lens.hook_points import HookedRootModule #type: ignore from iit.model_pairs.base_model_pair import BaseModelPair -from iit.utils.nodes import HLNode, LLNode from iit.utils.correspondence import Correspondence from iit.utils.metric import MetricStore, MetricStoreCollection, MetricType from iit.model_pairs.ll_model import LLModel @@ -17,13 +16,13 @@ def __init__( self, hl_model: HookedRootModule, ll_model: HookedRootModule | LLModel, - corr: "Correspondence", - training_args={}, + corr: Correspondence, + training_args: dict = {}, ): self.hl_model = hl_model self.hl_model.requires_grad_(False) - self.corr: dict[HLNode, set[LLNode]] = corr + self.corr = corr print(self.hl_model.hook_dict) print(self.corr.keys()) assert all([str(k) in self.hl_model.hook_dict for k in self.corr.keys()]) @@ -50,9 +49,9 @@ def __init__( self.wandb_method = "iit" @property - def loss_fn(self): + def loss_fn(self) -> Callable[[Tensor, Tensor], Tensor]: # TODO: make this more general - def class_loss(output, target): + def class_loss(output: 
Tensor, target: Tensor) -> Tensor: # convert to (N, C, ...) if necessary if len(target.shape) == len(output.shape) and len(output.shape) > 2: # convert target to float if necessary @@ -79,7 +78,7 @@ def class_loss(output, target): return class_loss @staticmethod - def make_train_metrics(): + def make_train_metrics() -> MetricStoreCollection: return MetricStoreCollection( [ MetricStore("train/iit_loss", MetricType.LOSS), @@ -87,7 +86,7 @@ def make_train_metrics(): ) @staticmethod - def make_test_metrics(): + def make_test_metrics() -> MetricStoreCollection: return MetricStoreCollection( [ MetricStore("val/iit_loss", MetricType.LOSS), @@ -97,10 +96,10 @@ def make_test_metrics(): def run_eval_step( self, - base_input: tuple[t.Tensor, t.Tensor, t.Tensor], - ablation_input: tuple[t.Tensor, t.Tensor, t.Tensor], + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor], - ): + ) -> dict: hl_node = self.sample_hl_name() # sample a high-level variable to ablate hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) loss = loss_fn(ll_output, hl_output) @@ -113,16 +112,16 @@ def run_eval_step( def run_train_step( self, - base_input: tuple[t.Tensor, t.Tensor, t.Tensor], - ablation_input: tuple[t.Tensor, t.Tensor, t.Tensor], + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor], optimizer: t.optim.Optimizer, - ): + ) -> dict: optimizer.zero_grad() hl_node = self.sample_hl_name() # sample a high-level variable to ablate loss = self.get_IIT_loss_over_batch( base_input, ablation_input, hl_node, loss_fn ) - loss.backward() + loss.backward() # type: ignore optimizer.step() return {"train/iit_loss": loss.item()} diff --git a/iit/model_pairs/ioi_model_pair.py b/iit/model_pairs/ioi_model_pair.py index 6a25270..27d3b37 100644 --- a/iit/model_pairs/ioi_model_pair.py +++ b/iit/model_pairs/ioi_model_pair.py @@ -1,16 +1,27 @@ +from typing import Callable, Optional + +import numpy as np +import torch as t +from torch import Tensor +from transformer_lens.hook_points import HookedRootModule #type: ignore + from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair +from iit.model_pairs.ll_model import LLModel +from iit.utils.correspondence import Correspondence from iit.utils.config import DEVICE from iit.utils.metric import MetricStore, MetricType, MetricStoreCollection, PerTokenMetricStore -import numpy as np -from typing import Callable -from torch import Tensor -import torch as t from iit.utils.nodes import HLNode import iit.utils.index as index class IOI_ModelPair(StrictIITModelPair): - def __init__(self, hl_model, ll_model, corr, training_args={}): + def __init__( + self, + hl_model: HookedRootModule, + ll_model: LLModel, + corr: Correspondence, + training_args: dict = {} + ): super().__init__(hl_model, ll_model, corr, training_args=training_args) default_training_args = { "next_token": False, @@ -19,13 +30,14 @@ def __init__(self, hl_model, ll_model, corr, training_args={}): } self.training_args = {**default_training_args, **self.training_args} self.next_token = self.training_args["next_token"] + self.__loss_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None @property - def loss_fn(self): - if hasattr(self, "__loss_fn"): + def loss_fn(self) -> Callable[[Tensor, Tensor], Tensor]: + if self.__loss_fn is not None: return self.__loss_fn - def per_token_weighted_cross_entropy(output, target): + def 
per_token_weighted_cross_entropy(output: Tensor, target: Tensor) -> Tensor: if target.shape == output.shape: target = target.argmax(dim=-1) # convert one-hot to index for cross_entropy if len(output.shape) == 2: @@ -46,11 +58,11 @@ def per_token_weighted_cross_entropy(output, target): return self.__loss_fn @staticmethod - def get_label_idxs(): + def get_label_idxs() -> index.TorchIndex: return index.Ix[:, -1] @staticmethod - def make_test_metrics(): + def make_test_metrics() -> MetricStoreCollection: return MetricStoreCollection( [ MetricStore("val/iit_loss", MetricType.LOSS), @@ -63,11 +75,11 @@ def make_test_metrics(): def get_IIT_loss_over_batch( self, - base_input: tuple[t.Tensor, t.Tensor, t.Tensor], - ablation_input: tuple[t.Tensor, t.Tensor, t.Tensor], + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], hl_node: HLNode, loss_fn: Callable[[Tensor, Tensor], Tensor], - ): + ) -> Tensor: hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) # hl_output = t.nn.functional.softmax(hl_output, dim=-1) hl_argmax = t.argmax(hl_output[:, -1, :], dim=-1) @@ -77,10 +89,10 @@ def get_IIT_loss_over_batch( def run_eval_step( self, - base_input, - ablation_input, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor], - ): + ) -> dict: # compute IIT loss and accuracy on last token position only hl_node = self.sample_hl_name() hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) @@ -95,8 +107,8 @@ def run_eval_step( # To handle the case when labels are one-hot hl_output = t.argmax(hl_output, dim=-1) top1 = t.argmax(ll_output, dim=-1) - accuracy = (top1[:, -1] == hl_output[:, -1]).float().mean() - IIA = accuracy.item() + accuracy = (top1[:, -1] == hl_output[:, -1]).float().mean().item() + IIA = accuracy # compute behavioral accuracy base_x, base_y = base_input[0:2] @@ -148,24 +160,30 @@ def run_eval_step( @staticmethod def _check_early_stop_fn( - test_metrics: list[MetricStore], - verbose=False, - non_ioi_thresh=0.65, - use_per_token_check=False, - ): + test_metrics: MetricStoreCollection, + verbose: bool = False, + non_ioi_thresh: float = 0.65, + use_per_token_check: bool = False, + ) -> bool: """ Early stopping for IOI """ - print_if_verbose = lambda x: print(x) if verbose else None # noqa: E731 + print_if_verbose: Callable[[t.Any], None] = lambda x: print(x) if verbose else None # noqa: E731 for metric in test_metrics: - if metric.get_name() == "val/IIA" and metric.get_value() < 100: - print_if_verbose(f"IIA is not enough: {metric.get_value()}") - return False - elif metric.get_name() == "val/strict_accuracy" and metric.get_value() < 100: - print_if_verbose(f"strict_accuracy is not enough: {metric.get_value()}") - return False + if metric.get_name() == "val/IIA": + val = metric.get_value() + if (isinstance(val, float) and val < 100) or isinstance(val, type(None)): + print_if_verbose(f"IIA is not enough: {metric.get_value()}") + return False + elif metric.get_name() == "val/strict_accuracy": + val = metric.get_value() + if (isinstance(val, float) and val < 100) or isinstance(val, type(None)): + print_if_verbose(f"strict_accuracy is not enough: {metric.get_value()}") + return False elif metric.get_name() == "val/per_token_accuracy": per_toke_acc = metric.get_value() + if not isinstance(per_toke_acc, list) and not isinstance(per_toke_acc, np.ndarray): + per_toke_acc = [per_toke_acc,] if per_toke_acc[-1] < 1: print_if_verbose( 
f"per_token_acc at IOI index is not enough: {per_toke_acc[-1]}" @@ -194,16 +212,12 @@ def _check_early_stop_fn( return False return True - def _check_early_stop_condition( - self, - *args, - **kwargs, - ): + def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool: if not self.training_args["next_token"]: - return super()._check_early_stop_condition(*args, **kwargs) + return super()._check_early_stop_condition(test_metrics) + return self._check_early_stop_fn( - *args, - **kwargs, - non_ioi_thresh=self.training_args["non_ioi_thresh"], - use_per_token_check=self.training_args["use_per_token_check"], - ) + test_metrics, + non_ioi_thresh=self.training_args["non_ioi_thresh"], + use_per_token_check=self.training_args["use_per_token_check"], + ) \ No newline at end of file diff --git a/iit/model_pairs/ll_model.py b/iit/model_pairs/ll_model.py index c230b6b..5df390a 100644 --- a/iit/model_pairs/ll_model.py +++ b/iit/model_pairs/ll_model.py @@ -1,4 +1,7 @@ +from typing import Optional, Callable + import torch as t +from torch import Tensor from transformer_lens import HookedTransformer from typing import Optional, Tuple from transformer_lens.hook_points import NamesFilter, HookPoint, HookedRootModule @@ -11,8 +14,9 @@ class LLModel: """ def __init__(self, model: HookedRootModule = None, - cfg: dict = None, - detach_while_caching=True): + cfg: Optional[dict] = None, + detach_while_caching: bool = True + ): assert model is not None or cfg is not None, "Either model or cfg must be provided." if model is None: model = HookedTransformer(cfg=cfg) @@ -23,7 +27,7 @@ def get_caching_hooks( self, names_filter: NamesFilter = None, incl_bwd: bool = False, - device=None, + device: Optional[t.device | str] = None, remove_batch_dim: bool = False, cache: Optional[dict] = None, ) -> Tuple[dict, list, list]: @@ -54,7 +58,7 @@ def get_caching_hooks( names_filter = lambda name: name in filter_list self.is_caching = True - def save_hook(tensor: t.Tensor, hook: HookPoint): + def save_hook(tensor: t.Tensor, hook: HookPoint) -> None: if self.detach_while_caching or (not (tensor.requires_grad and self.model.training)): # detach if the tensor requires grad and the model is not training tensor_to_cache = tensor.detach() @@ -73,7 +77,7 @@ def save_hook(tensor: t.Tensor, hook: HookPoint): else: cache[hook.name] = tensor_to_cache.to(device) - def save_hook_back(tensor, hook): + def save_hook_back(tensor: Tensor, hook: HookPoint) -> None: # we always detach here as loss.backward() was already called # and will throw an error if we don't do this tensor_to_cache = tensor.detach() @@ -93,22 +97,22 @@ def save_hook_back(tensor, hook): return cache, fwd_hooks, bwd_hooks @classmethod - def make_from_hooked_transformer(cls, hooked_transformer: HookedTransformer, detach_while_caching): + def make_from_hooked_transformer(cls, hooked_transformer: HookedTransformer, detach_while_caching: bool) -> "LLModel": ll_model = cls(hooked_transformer, detach_while_caching=detach_while_caching) ll_model.load_state_dict(hooked_transformer.state_dict()) return ll_model - def run_with_cache( + def run_with_cache( # type: ignore self, *model_args, - names_filter: NamesFilter = None, - device=None, - remove_batch_dim=False, - incl_bwd=False, - reset_hooks_end=True, - clear_contexts=False, + names_filter: Optional[NamesFilter] = None, + device: Optional[t.device | str] = None, + remove_batch_dim: bool = False, + incl_bwd: bool = False, + reset_hooks_end: bool = True, + clear_contexts: bool = False, **model_kwargs, - ): + ) -> 
Tuple[Tensor, ActivationCache]: """ Runs the model and returns the model output and a Cache object. @@ -153,7 +157,7 @@ def run_with_cache( ) return model_out, cache_dict - def __getattr__(self, name): + def __getattr__(self, name: str) -> t.Any: if name == "run_with_cache": return self.run_with_cache elif name == "get_caching_hooks": @@ -163,8 +167,8 @@ def __getattr__(self, name): def __call__(self, *args: t.Any, **kwds: t.Any) -> t.Any: return self.model(*args, **kwds) - def __repr__(self): + def __repr__(self) -> str: return self.model.__repr__() - def __str__(self): + def __str__(self) -> str: return self.model.__str__() \ No newline at end of file diff --git a/iit/model_pairs/probed_sequential_pair.py b/iit/model_pairs/probed_sequential_pair.py index 9ba699a..88c4bf1 100644 --- a/iit/model_pairs/probed_sequential_pair.py +++ b/iit/model_pairs/probed_sequential_pair.py @@ -1,50 +1,56 @@ -from iit.model_pairs.iit_model_pair import * -from iit.utils.probes import construct_probes +from typing import Callable, Type + +import wandb +import numpy as np +import torch as t +from torch import Tensor +from tqdm import tqdm # type: ignore +from transformer_lens.hook_points import HookedRootModule #type: ignore + +from iit.model_pairs.iit_model_pair import IITModelPair +from iit.utils.config import WANDB_ENTITY +from iit.utils.probes import construct_probes #type: ignore +from iit.utils.correspondence import Correspondence +from iit.utils.iit_dataset import IITDataset +from iit.model_pairs.ll_model import LLModel class IITProbeSequentialPair(IITModelPair): def __init__( self, - hl_model: HookedRootModule = None, - ll_model: HookedRootModule = None, - hl_graph=None, - corr: dict[HLNode, set[LLNode]] = {}, - seed=0, - training_args={}, + hl_model: HookedRootModule, + ll_model: HookedRootModule | LLModel, + corr: Correspondence, + training_args: dict = {}, ): - super().__init__(hl_model, ll_model, hl_graph, corr, seed, training_args) default_training_args = { - "batch_size": 256, - "lr": 0.001, - "num_workers": 0, "probe_weight": 1.0, } training_args = {**default_training_args, **training_args} - self.training_args = training_args + super().__init__(hl_model, ll_model, corr, training_args) - def run_train_step( + def run_train_step( # type: ignore self, - base_input, - ablation_input, - loss_fn, - optimizer, - probes, - probe_optimizer, - training_args, - ): + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], + loss_fn: Callable[[Tensor, Tensor], Tensor], + optimizer: t.optim.Optimizer, + probes: dict, + probe_optimizer: t.optim.Optimizer, + training_args: dict, + ) -> dict: ablation_loss = super().run_train_step( base_input, ablation_input, loss_fn, optimizer )["iit_loss"] # !!! 
Second forward pass # add probe losses and behavior loss - probe_losses = [] probe_optimizer.zero_grad() for p in probes.values(): p.train() base_x, base_y, base_intermediate_vars = base_input out, cache = self.ll_model.run_with_cache(base_x) - probe_loss = 0 + probe_loss = t.zeros(1) for hl_node_name in probes.keys(): gt = self.hl_model.get_idx_to_intermediate(hl_node_name)( base_intermediate_vars @@ -74,27 +80,36 @@ def run_train_step( def train( self, - dataset, - test_dataset, - epochs=1000, - use_wandb=False, - ): + train_set: IITDataset, + test_set: IITDataset, + optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam, + epochs: int = 1000, + use_wandb: bool = False, + wandb_name_suffix: str = "", + optimizer_kwargs: dict = {}, + ) -> None: training_args = self.training_args print(f"{training_args=}") # add to make probes - input_shape = (dataset[0][0][0]).unsqueeze(0).shape + input_shape = (train_set[0][0][0]).unsqueeze(0).shape with t.no_grad(): probes = construct_probes(self, input_shape) print("made probes", [(k, p.weight.shape) for k, p in probes.items()]) - loader, test_loader = self.make_loaders(dataset, test_dataset) + loader, test_loader = self.make_loaders( + train_set, + test_set, + training_args["batch_size"], + training_args["num_workers"] + ) params = list(self.ll_model.parameters()) for p in probes.values(): params += list(p.parameters()) - probe_optimizer = t.optim.Adam(params, lr=training_args["lr"]) + optimizer_kwargs['lr'] = training_args["lr"] + probe_optimizer = optimizer_cls(params, **optimizer_kwargs) - optimizer = t.optim.Adam(self.ll_model.parameters(), lr=training_args["lr"]) + optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs) loss_fn = t.nn.CrossEntropyLoss() if use_wandb and not wandb.run: @@ -155,6 +170,7 @@ def train( test_behavior_losses.append(behavior_loss.item()) behavior_accuracies.append(behavior_accuracy.item()) + probe_loss = t.zeros(1) for hl_node_name in probes.keys(): gt = self.hl_model.get_idx_to_intermediate(hl_node_name)( base_intermediate_vars diff --git a/iit/model_pairs/stop_grad_pair.py b/iit/model_pairs/stop_grad_pair.py index ad4cddb..c032df2 100644 --- a/iit/model_pairs/stop_grad_pair.py +++ b/iit/model_pairs/stop_grad_pair.py @@ -1,22 +1,26 @@ -from typing import Any -from iit.model_pairs.freeze_model_pair import FreezedModelPair -from transformer_lens.hook_points import HookPoint +from typing import Any, Callable + from torch import Tensor import torch +from transformer_lens.hook_points import HookPoint, HookedRootModule #type: ignore +from transformer_lens import HookedTransformer #type: ignore + +from iit.model_pairs.freeze_model_pair import FreezedModelPair +from iit.model_pairs.ll_model import LLModel from iit.utils.nodes import LLNode import iit.utils.node_picker as node_picker -from transformer_lens import HookedTransformer +from iit.utils.correspondence import Correspondence class StopGradHookedModel: def __init__( self, model: HookedTransformer, - params_not_in_circuit, - nodes_not_in_circuit, - post_nodes_not_in_circuit, - scale=1e6, - use_forward_hooks = True, + params_not_in_circuit: list[LLNode], + nodes_not_in_circuit: list[LLNode], + post_nodes_not_in_circuit: list[LLNode], + scale: float = 1e6, + use_forward_hooks: bool = True, ): self.model = model self.params_not_in_circuit = params_not_in_circuit @@ -38,8 +42,8 @@ def __getattr__(self, __name: str) -> Any: ) @staticmethod - def make_ln_hook(ll_node: LLNode, scale): - def hook_fn(hook_point_out: Tensor, hook: HookPoint) -> torch.Tensor: + def 
make_ln_hook(ll_node: LLNode, scale: float) -> Callable[[Tensor, HookPoint], Tensor]: + def hook_fn(hook_point_out: Tensor, hook: HookPoint) -> Tensor: return ( hook_point_out / scale ) # TODO: this won't work when individual heads are switched on/off @@ -47,10 +51,10 @@ def hook_fn(hook_point_out: Tensor, hook: HookPoint) -> torch.Tensor: return hook_fn @staticmethod - def make_detached_hook(ll_node: LLNode): + def make_detached_hook(ll_node: LLNode) -> Callable[[Tensor, HookPoint], Tensor]: print(f"Attaching hook to {ll_node.name}") - def hook_fn(hook_point_out: Tensor, hook: HookPoint) -> torch.Tensor: + def hook_fn(hook_point_out: Tensor, hook: HookPoint) -> Tensor: act_idx = ll_node.get_index() hook_point_out[act_idx] = ( hook_point_out[act_idx].clone().detach() @@ -62,8 +66,8 @@ def hook_fn(hook_point_out: Tensor, hook: HookPoint) -> torch.Tensor: return hook_fn - def make_zero_grad_hook(self, ll_node: LLNode): - def hook_fn(grad: Tensor, hook: HookPoint) -> torch.Tensor: + def make_zero_grad_hook(self, ll_node: LLNode) -> Callable[[Tensor, HookPoint], list[Tensor]]: + def hook_fn(grad: Tensor, hook: HookPoint) -> list[Tensor]: act_idx = ll_node.get_index() ori_grad_shape = grad.shape grad[act_idx] = torch.zeros_like(grad[act_idx]) @@ -77,7 +81,7 @@ def hook_fn(grad: Tensor, hook: HookPoint) -> torch.Tensor: return hook_fn - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: self.model.reset_hooks() if self.use_forward_hooks: return self.model.run_with_hooks( @@ -107,7 +111,13 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: class StopGradModelPair(FreezedModelPair): - def __init__(self, hl_model, ll_model, corr, training_args={}): + def __init__( + self, + hl_model: HookedRootModule, + ll_model: HookedRootModule | LLModel, + corr: Correspondence, + training_args: dict = {} + ): default_training_args = { "batch_size": 256, "lr": 0.001, @@ -133,7 +143,7 @@ def __init__(self, hl_model, ll_model, corr, training_args={}): post_nodes_not_in_circuit, scale=training_args["scale"], use_forward_hooks=training_args["use_ln_hooks"], - ) + ) #type: ignore self.wandb_method = "stop grads" # TODO: test another part of the model and see if the gradient changes after registering the hook diff --git a/iit/model_pairs/strict_iit_model_pair.py b/iit/model_pairs/strict_iit_model_pair.py index 619738f..9b7fc16 100644 --- a/iit/model_pairs/strict_iit_model_pair.py +++ b/iit/model_pairs/strict_iit_model_pair.py @@ -1,15 +1,25 @@ import numpy as np import torch as t +from torch import Tensor +from transformer_lens.hook_points import HookedRootModule #type: ignore -import iit.utils.node_picker as node_picker from iit.model_pairs.base_model_pair import Callable, Tensor from iit.model_pairs.iit_behavior_model_pair import IITBehaviorModelPair +from iit.model_pairs.ll_model import LLModel +import iit.utils.node_picker as node_picker from iit.utils.nodes import LLNode from iit.utils.metric import MetricStore, MetricStoreCollection, MetricType +from iit.utils.correspondence import Correspondence class StrictIITModelPair(IITBehaviorModelPair): - def __init__(self, hl_model, ll_model, corr, training_args={}): + def __init__( + self, + hl_model: HookedRootModule, + ll_model: HookedRootModule | LLModel, + corr: Correspondence, + training_args: dict = {} + ): default_training_args = { "batch_size": 256, "lr": 0.001, @@ -27,7 +37,7 @@ def __init__(self, hl_model, ll_model, corr, training_args={}): ) @staticmethod - def make_train_metrics(): + def make_train_metrics() -> MetricStoreCollection: return 
MetricStoreCollection( [ MetricStore("train/iit_loss", MetricType.LOSS), @@ -37,27 +47,23 @@ def make_train_metrics(): ) @staticmethod - def make_test_metrics(): + def make_test_metrics() -> MetricStoreCollection: return MetricStoreCollection( IITBehaviorModelPair.make_test_metrics().metrics + [MetricStore("val/strict_accuracy", MetricType.ACCURACY)], ) def sample_ll_node(self) -> LLNode: - return self.rng.choice(self.nodes_not_in_circuit) + return self.rng.choice(np.array(self.nodes_not_in_circuit, dtype=object)) def run_train_step( self, - base_input, - ablation_input, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], loss_fn: Callable[[Tensor, Tensor], Tensor], optimizer: t.optim.Optimizer, - ): + ) -> dict: use_single_loss = self.training_args["use_single_loss"] - iit_loss = 0 - ll_loss = 0 - behavior_loss = 0 - hl_node = self.sample_hl_name() # sample a high-level variable to ablate iit_loss = ( self.get_IIT_loss_over_batch(base_input, ablation_input, hl_node, loss_fn) @@ -102,7 +108,12 @@ def run_train_step( "train/strict_loss": ll_loss.item(), } - def run_eval_step(self, base_input, ablation_input, loss_fn: Callable[[Tensor, Tensor], Tensor]): + def run_eval_step( + self, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], + loss_fn: Callable[[Tensor, Tensor], Tensor] + ) -> dict: eval_returns = super().run_eval_step(base_input, ablation_input, loss_fn) base_x, base_y = base_input[0:2] ablation_x, ablation_y = ablation_input[0:2] @@ -127,7 +138,7 @@ def run_eval_step(self, base_input, ablation_input, loss_fn: Callable[[Tensor, T accuracies.append(accuracy) if len(accuracies) > 0: - accuracy = np.mean(accuracies) + accuracy = float(np.mean(accuracies)) else: accuracy = 1.0 @@ -135,7 +146,7 @@ def run_eval_step(self, base_input, ablation_input, loss_fn: Callable[[Tensor, T return eval_returns - def _check_early_stop_condition(self, test_metrics: list[MetricStore]): + def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool: metrics_to_check = [] for metric in test_metrics: if metric.get_name() == "val/strict_accuracy" and self.training_args["strict_weight"] > 0: @@ -144,4 +155,4 @@ def _check_early_stop_condition(self, test_metrics: list[MetricStore]): metrics_to_check.append(metric) if metric.get_name() == "val/IIA" and self.training_args["iit_weight"] > 0: metrics_to_check.append(metric) - return super()._check_early_stop_condition(metrics_to_check) \ No newline at end of file + return super()._check_early_stop_condition(MetricStoreCollection(metrics_to_check)) \ No newline at end of file diff --git a/iit/tasks/hl_model.py b/iit/tasks/hl_model.py index ea9c372..ee9c71f 100644 --- a/iit/tasks/hl_model.py +++ b/iit/tasks/hl_model.py @@ -4,5 +4,5 @@ class HLModel(ABC): @abstractmethod - def is_categorical(self): + def is_categorical(self) -> bool: pass \ No newline at end of file diff --git a/iit/tasks/ioi/__init__.py b/iit/tasks/ioi/__init__.py index 44709ae..d0f3595 100644 --- a/iit/tasks/ioi/__init__.py +++ b/iit/tasks/ioi/__init__.py @@ -18,31 +18,31 @@ } -def make_corr_dict(include_mlp=False, eval=False, use_pos_embed=False): +def make_corr_dict(include_mlp: bool = False, eval: bool = False, use_pos_embed: bool = False) -> dict: all_attns = [f"blocks.{i}.attn.hook_z" for i in range(ioi_cfg["n_layers"])] all_mlps = [f"blocks.{i}.mlp.hook_post" for i in range(ioi_cfg["n_layers"])] if eval: all_nodes_hook = "blocks.0.hook_resid_pre" if not use_pos_embed else 
"blocks.0.hook_pos_embed" return { - "hook_duplicate": [all_attns[1]], + "hook_duplicate": [[all_attns[1], Ix[[None]], None]], # "hook_previous": ["blocks.1.attn.hook_result"], - "hook_s_inhibition": [all_attns[2]], - "hook_name_mover": [all_attns[4]], + "hook_s_inhibition": [[all_attns[2], Ix[[None]], None]], + "hook_name_mover": [[all_attns[4], Ix[[None]], None]], "all_nodes_hook": ( - [all_nodes_hook, all_mlps[0]] + [[all_nodes_hook, Ix[[None]], None], [all_mlps[0], Ix[[None]], None]] if include_mlp - else [all_nodes_hook] + else [all_nodes_hook, Ix[[None]], None] ), - "hook_out": [f"blocks.{n_layers-1}.hook_resid_post"], + "hook_out": [[f"blocks.{n_layers-1}.hook_resid_post", Ix[[None]], None]], } ans = { - "hook_duplicate": [all_attns[1]], + "hook_duplicate": [[all_attns[1], Ix[[None]], None]], # "hook_previous": ["blocks.1.attn.hook_result"], - "hook_s_inhibition": [all_attns[2]], - "hook_name_mover": [all_attns[4]], + "hook_s_inhibition": [[all_attns[2], Ix[[None]], None]], + "hook_name_mover": [[all_attns[4], Ix[[None]], None]], } if include_mlp: - ans["all_nodes_hook"] = [all_mlps[0]] + ans["all_nodes_hook"] = [[all_mlps[0], Ix[[None]], None]] return ans @@ -69,8 +69,8 @@ def make_corr_dict(include_mlp=False, eval=False, use_pos_embed=False): ) -def make_ll_edges(corr: Correspondence): - def expand_nodes(ll_node: LLNode): +def make_ll_edges(corr: Correspondence) -> list[tuple[LLNode, LLNode]]: + def expand_nodes(ll_node: LLNode) -> list[LLNode]: ll_nodes_expanded = [] for head_index in range(n_heads): idx = Ix[:, :, head_index, :] @@ -85,8 +85,8 @@ def expand_nodes(ll_node: LLNode): hl_node_to = HLNode(edge[1], -1) ll_nodes_from = corr[hl_node_from] # set of LLNodes ll_nodes_to = corr[hl_node_to] - additional_from_nodes = set() - remove_from_nodes = set() + additional_from_nodes: set[LLNode] = set() + remove_from_nodes: set[LLNode] = set() for ll_node_from in ll_nodes_from: if "attn" in ll_node_from.name: ll_nodes_from_expanded = expand_nodes(ll_node_from) @@ -98,8 +98,8 @@ def expand_nodes(ll_node: LLNode): ll_nodes_from = ll_nodes_from | additional_from_nodes ll_nodes_from = ll_nodes_from - remove_from_nodes - additional_to_nodes = set() - remove_to_nodes = set() + additional_to_nodes: set[LLNode] = set() + remove_to_nodes: set[LLNode] = set() for ll_node_to in ll_nodes_to: if "attn" in ll_node_to.name: ll_nodes_to_expanded = expand_nodes(ll_node_to) diff --git a/iit/tasks/ioi/ioi_config.py b/iit/tasks/ioi/ioi_config.py index 267b143..2d07f37 100644 --- a/iit/tasks/ioi/ioi_config.py +++ b/iit/tasks/ioi/ioi_config.py @@ -18,6 +18,7 @@ import re import matplotlib.pyplot as plt import copy +from typing import Iterator, Optional NAMES = [ "Michael", @@ -259,7 +260,7 @@ ] -def multiple_replace(dict, text): +def multiple_replace(dict: dict, text: str) -> str: # from: https://stackoverflow.com/questions/15175142/how-can-i-do-multiple-substitutions-using-regex # Create a regular expression from the dictionary keys regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys()))) @@ -268,13 +269,13 @@ def multiple_replace(dict, text): return regex.sub(lambda mo: dict[mo.string[mo.start() : mo.end()]], text) -def iter_sample_fast(iterable, samplesize, seed): +def iter_sample_fast(iterator: Iterator, samplesize: int, seed: int) -> list: random.seed(seed) results = [] # Fill in the first samplesize elements: try: for _ in range(samplesize): - results.append(next(iterable)) + results.append(next(iterator)) except StopIteration: raise ValueError("Sample larger than population.") 
random.shuffle(results) # Randomize their positions @@ -286,8 +287,15 @@ def iter_sample_fast(iterable, samplesize, seed): def gen_prompt_uniform( - templates, names, nouns_dict, N, symmetric, prefixes=None, abc=False, seed=None, -): + templates: list[str], + names: list[str], + nouns_dict: dict, + N: int, + symmetric: bool, + prefixes: Optional[list[str]] = None, + abc: bool = False, + seed: Optional[int] = None, +) -> list[dict]: assert seed is not None random.seed(seed) @@ -346,7 +354,12 @@ def gen_prompt_uniform( return ioi_prompts -def gen_flipped_prompts(prompts, names, flip=("S2", "IO"), seed=None): +def gen_flipped_prompts( + prompts: list[dict], + names: list[str], + flip: tuple[str, str] = ("S2", "IO"), + seed: Optional[int] = None + ) -> list[dict]: """_summary_ Args: @@ -489,8 +502,13 @@ def gen_flipped_prompts(prompts, names, flip=("S2", "IO"), seed=None): # *Tok Idxs Methods -def get_name_idxs(prompts, tokenizer, idx_types=["IO", "S", "S2"], prepend_bos=False): - name_idx_dict = dict((idx_type, []) for idx_type in idx_types) +def get_name_idxs( + prompts: list[dict], + tokenizer: AutoTokenizer, + idx_types: list[str] = ["IO", "S", "S2"], + prepend_bos: bool = False + ) -> list[torch.Tensor]: + name_idx_dict: dict[str, list] = dict((idx_type, []) for idx_type in idx_types) double_s2 = False for prompt in prompts: t = prompt["text"].split(" ") @@ -519,7 +537,11 @@ def get_name_idxs(prompts, tokenizer, idx_types=["IO", "S", "S2"], prepend_bos=F ] -def get_word_idxs(prompts, word_list, tokenizer): +def get_word_idxs( + prompts: list[dict], + word_list: list[str], + tokenizer: AutoTokenizer + ) -> torch.Tensor: """Get the index of the words in word_list in the prompts. Exactly one of the word_list word has to be present in each prompt""" idxs = [] tokenized_words = [ @@ -548,7 +570,13 @@ def get_word_idxs(prompts, word_list, tokenizer): return torch.tensor(idxs) -def get_end_idxs(prompts, tokenizer, name_tok_len=1, prepend_bos=False, toks=None): +def get_end_idxs( + prompts: list[dict], + tokenizer: AutoTokenizer, + toks: torch.Tensor, + name_tok_len: int = 1, + prepend_bos: bool = False, + ) -> torch.Tensor: # toks = torch.Tensor(tokenizer([prompt["text"] for prompt in prompts], padding=True).input_ids).type(torch.int) relevant_idx = int(prepend_bos) # if the sentence begins with an end token @@ -603,7 +631,12 @@ def get_end_idxs(prompts, tokenizer, name_tok_len=1, prepend_bos=False, toks=Non ] # , "verb", "starts", "S-1", "punct"] # Kevin's antic averages -def get_idx_dict(ioi_prompts, tokenizer, prepend_bos=False, toks=None): +def get_idx_dict( + ioi_prompts: list[dict], + tokenizer: AutoTokenizer, + toks: torch.Tensor, + prepend_bos: bool = False, + ) -> dict: (IO_idxs, S_idxs, S2_idxs,) = get_name_idxs( ioi_prompts, tokenizer, @@ -647,7 +680,7 @@ def get_idx_dict(ioi_prompts, tokenizer, prepend_bos=False, toks=None): ] -def flip_prefixes(ioi_prompts): +def flip_prefixes(ioi_prompts: list[dict]) -> list[dict]: ioi_prompts = copy.deepcopy(ioi_prompts) for prompt in ioi_prompts: if prompt["text"].startswith("The "): @@ -661,7 +694,7 @@ def flip_prefixes(ioi_prompts): return ioi_prompts -def flip_names(ioi_prompts): +def flip_names(ioi_prompts: list[dict]) -> list[dict]: ioi_prompts = copy.deepcopy(ioi_prompts) for prompt in ioi_prompts: punct_idx = max( @@ -687,16 +720,16 @@ def __init__( prompt_type: Union[ str, List[str] ], # if list, then it will be a list of templates - N=500, - tokenizer=None, - prompts=None, - symmetric=False, - prefixes=None, - nb_templates=None, - 
ioi_prompts_for_word_idxs=None, - prepend_bos=False, - manual_word_idx=None, - seed=None, + N: int = 500, + tokenizer: Optional[AutoTokenizer] = None, + prompts: Optional[list[dict]] = None, + symmetric: bool = False, + prefixes: Optional[list[str]] = None, + nb_templates: Optional[int] = None, + ioi_prompts_for_word_idxs: Optional[list[dict]] = None, + prepend_bos: bool = False, + manual_word_idx: Optional[dict] = None, + seed: Optional[int] = None, ): """ ioi_prompts_for_word_idxs: @@ -710,7 +743,7 @@ def __init__( if not ( N == 1 or prepend_bos == False - or tokenizer.bos_token_id == tokenizer.eos_token_id + or (tokenizer is not None and (tokenizer.bos_token_id == tokenizer.eos_token_id)) ): warnings.warn( "Probably word_idx will be calculated incorrectly due to this formatting" @@ -767,7 +800,7 @@ def __init__( symmetric=symmetric, prefixes=self.prefixes, abc=(prompt_type in ["ABC", "ABC mixed", "BAC"]), - seed = (seed + 987654321) % 123456789, + seed = (seed + 987654321) % 123456789 if seed is not None else None, ) else: assert N == len(prompts), f"{N} and {len(prompts)}" @@ -840,7 +873,7 @@ def __init__( self.tokenizer.encode(" " + prompt["S"])[0] for prompt in self.ioi_prompts ] - self.tokenized_prompts = [] + self.tokenized_prompts: list = [] for i in range(self.N): self.tokenized_prompts.append( @@ -848,7 +881,12 @@ def __init__( ) @classmethod - def construct_from_ioi_prompts_metadata(cls, templates, ioi_prompts_data, **kwargs): + def construct_from_ioi_prompts_metadata( # type: ignore + cls, + templates: list[str], + ioi_prompts_data: list[dict], + **kwargs + ) -> "IOIDataset": """ Given a list of dictionaries (ioi_prompts_data) { @@ -874,7 +912,7 @@ def construct_from_ioi_prompts_metadata(cls, templates, ioi_prompts_data, **kwar # prompts[-1]["[OBJECT]"] = metadata["[OBJECT]"] return IOIDataset(prompt_type=templates, prompts=prompts, **kwargs) - def gen_flipped_prompts(self, flip, seed=None): + def gen_flipped_prompts(self, flip: tuple[str, str], seed: Optional[int] = None) -> "IOIDataset": """ Return a IOIDataset where the name to flip has been replaced by a random name. 
""" @@ -891,14 +929,14 @@ def gen_flipped_prompts(self, flip, seed=None): if flip in [("IO", "S1"), ("S", "IO")]: flipped_prompts = gen_flipped_prompts( self.ioi_prompts, - None, + [], flip, seed=(seed+12345)%9876, ) elif flip == ("S2", "IO"): flipped_prompts = gen_flipped_prompts( self.ioi_prompts, - None, + [], flip, seed=(seed+12345)%6543, ) @@ -927,7 +965,7 @@ def gen_flipped_prompts(self, flip, seed=None): ) return flipped_ioi_dataset - def copy(self): + def copy(self) -> "IOIDataset": copy_ioi_dataset = IOIDataset( prompt_type=self.prompt_type, N=self.N, @@ -940,8 +978,10 @@ def copy(self): ) return copy_ioi_dataset - def __getitem__(self, key): + def __getitem__(self, key: Union[int, slice]) -> "IOIDataset": sliced_prompts = self.ioi_prompts[key] + if isinstance(sliced_prompts, dict): + sliced_prompts = [sliced_prompts] sliced_dataset = IOIDataset( prompt_type=self.prompt_type, N=len(sliced_prompts), @@ -952,18 +992,15 @@ def __getitem__(self, key): ) return sliced_dataset - def __setitem__(self, key, value): + def __setitem__(self, key: torch.Any, value: torch.Any) -> None: raise NotImplementedError() - def __delitem__(self, key): + def __delitem__(self, key: torch.Any) -> None: raise NotImplementedError() - def __len__(self): + def __len__(self) -> int: return self.N - def tokenized_prompts(self): - return self.toks - # tests that the templates work as intended # assert len(BABA_EARLY_IOS) == len(BABA_LATE_IOS), (len(BABA_EARLY_IOS), len(BABA_LATE_IOS)) diff --git a/iit/tasks/ioi/ioi_dataset_tl.py b/iit/tasks/ioi/ioi_dataset_tl.py index 1cc6922..2fa6aed 100644 --- a/iit/tasks/ioi/ioi_dataset_tl.py +++ b/iit/tasks/ioi/ioi_dataset_tl.py @@ -9,16 +9,20 @@ from typing import Dict, List, Optional import einops -import torch +import torch as t +from torch import Tensor import tqdm.auto as tqdm from datasets import load_dataset from torch.utils.data import DataLoader, Dataset +from transformers import AutoTokenizer -from transformer_lens import utils +from transformer_lens import utils, HookedTransformer +from iit.utils.config import DEVICE +from iit.utils.iit_dataset import dataset_len # %% -def sanity_check(model): +def sanity_check(model: HookedTransformer) -> Tensor: """ Very basic eval - just feeds a string into the model (in this case, the first paragraph of Circuits: Zoom In), and returns the loss. It's a rough and quick sanity check - if the loss is <5 the model is probably OK, if the loss is >7 something's gone wrong. @@ -31,7 +35,7 @@ def sanity_check(model): # %% -def make_wiki_data_loader(tokenizer, batch_size=8): +def make_wiki_data_loader(tokenizer: AutoTokenizer, batch_size: int = 8) -> DataLoader: """ Evaluate on Wikitext 2, a dump of Wikipedia articles. (Using the train set because it's larger, I don't really expect anyone to bother with quarantining the validation set nowadays.) @@ -46,7 +50,7 @@ def make_wiki_data_loader(tokenizer, batch_size=8): return data_loader -def make_owt_data_loader(tokenizer, batch_size=8): +def make_owt_data_loader(tokenizer: AutoTokenizer, batch_size: int = 8) -> DataLoader: """ Evaluate on OpenWebText an open source replication of the GPT-2 training corpus (Reddit links with >3 karma) @@ -61,7 +65,7 @@ def make_owt_data_loader(tokenizer, batch_size=8): return data_loader -def make_pile_data_loader(tokenizer, batch_size=8): +def make_pile_data_loader(tokenizer: AutoTokenizer, batch_size: int = 8) -> DataLoader: """ Evaluate on the first 10k texts from The Pile. 
@@ -77,7 +81,7 @@ def make_pile_data_loader(tokenizer, batch_size=8): return data_loader -def make_code_data_loader(tokenizer, batch_size=8): +def make_code_data_loader(tokenizer: AutoTokenizer, batch_size: int = 8) -> DataLoader: """ Evaluate on the CodeParrot dataset, a dump of Python code. @@ -105,8 +109,13 @@ def make_code_data_loader(tokenizer, batch_size=8): # %% -@torch.inference_mode() -def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"): +@t.inference_mode() +def evaluate_on_dataset( + model: HookedTransformer, + data_loader: DataLoader, + truncate: int = 100, + device: str = "cuda" + ) -> float: running_loss = 0 total = 0 for batch in tqdm.tqdm(data_loader): @@ -119,10 +128,15 @@ def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"): # %% -@torch.inference_mode() +@t.inference_mode() def induction_loss( - model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=None, device="cuda" -): + model: HookedTransformer, + tokenizer: Optional[AutoTokenizer] = None, + batch_size: int = 4, + subseq_len: int = 384, + prepend_bos: Optional[bool] = None, + device: str = "cuda" +) -> Tensor: """ Generates a batch of random sequences repeated twice, and measures model performance on the second half. Tests whether a model has induction heads. @@ -130,7 +144,7 @@ def induction_loss( whose default is True unless specified otherwise), which is useful to give models a resting position, and sometimes models were trained with this. """ # Make the repeated sequence - first_half_tokens = torch.randint(100, 20000, (batch_size, subseq_len)).to(device) + first_half_tokens = t.randint(100, 20000, (batch_size, subseq_len)).to(device) repeated_tokens = einops.repeat(first_half_tokens, "b p -> b (2 p)") # Use the provided prepend_bos as an override if it's not None; @@ -154,8 +168,13 @@ def induction_loss( # %% -@torch.inference_mode() -def evaluate(model, truncate=100, batch_size=8, tokenizer=None): +@t.inference_mode() +def evaluate( + model: HookedTransformer, + truncate: int = 100, + batch_size: int = 8, + tokenizer: Optional[AutoTokenizer] = None + ) -> Dict[str, float]: if tokenizer is None: tokenizer = model.tokenizer losses = {} @@ -201,7 +220,7 @@ class IOIDataset(Dataset): def __init__( self, - tokenizer, + tokenizer: AutoTokenizer, templates: Optional[List[str]] = None, names: Optional[List[str]] = None, nouns: Optional[Dict[str, List[str]]] = None, @@ -209,9 +228,11 @@ def __init__( symmetric: bool = False, prepend_bos: bool = True, seed: int = 42, + device: t.device = DEVICE, ): self.tokenizer = tokenizer self.prepend_bos = prepend_bos + self.device = device self.templates = ( templates if templates is not None else self.get_default_templates() @@ -227,10 +248,10 @@ def __init__( # If symmetric, get_sample will return two samples self.samples.extend(self.get_sample(symmetric=symmetric)) - def __len__(self): + def __len__(self) -> int: return len(self.samples) - def __getitem__(self, idx, pad_token=False): + def __getitem__(self, idx: int, pad_token: bool = False) -> Dict[str, Tensor]: sample = self.samples[idx] prompt = self.tokenizer.encode(sample["text"]) if self.prepend_bos: @@ -241,13 +262,13 @@ def __getitem__(self, idx, pad_token=False): idx_to_ablate = len(prompt) - 2 return { - "prompt": torch.LongTensor(prompt), - "IO": torch.LongTensor(self.tokenizer.encode(sample["IO"])), - "S": torch.LongTensor(self.tokenizer.encode(sample["S"])), - "idx_to_ablate": idx_to_ablate, + "prompt": t.LongTensor(prompt), + "IO": 
t.LongTensor(self.tokenizer.encode(sample["IO"])), + "S": t.LongTensor(self.tokenizer.encode(sample["S"])), + "idx_to_ablate": t.LongTensor((idx_to_ablate,)), } - def get_sample(self, symmetric=False) -> List[Dict[str, str]]: + def get_sample(self, symmetric: bool = False) -> List[Dict[str, str]]: template: str = random.choice(self.templates) for noun_type, noun_list in self.nouns.items(): template = template.replace(f"[{noun_type}]", random.choice(noun_list)) @@ -271,11 +292,11 @@ def get_sample(self, symmetric=False) -> List[Dict[str, str]]: return samples @staticmethod - def get_default_names(): + def get_default_names() -> List[str]: return ["John", "Mary"] @staticmethod - def get_default_templates(): + def get_default_templates() -> List[str]: return [ "Then, [B] and [A] went to the [LOCATION]. [A] gave the [OBJECT] to [B]", "Then, [A] and [B] went to the [LOCATION]. [B] gave the [OBJECT] to [A]", @@ -284,17 +305,22 @@ def get_default_templates(): ] @staticmethod - def get_default_nouns(): + def get_default_nouns() -> Dict[str, List[str]]: return { "LOCATION": ["store", "market"], "OBJECT": ["milk", "eggs", "bread"], } -@torch.inference_mode() +@t.inference_mode() def ioi_eval( - model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False -): + model: HookedTransformer, + dataset: Optional[Dataset] = None, + batch_size: int = 8, + num_samples: int = 1000, + tokenizer: Optional[AutoTokenizer] = None, + symmetric: bool = False +) -> Dict[str, float]: """Evaluate the Model on the Indirect Object Identification Task. Args: @@ -314,9 +340,9 @@ def ioi_eval( if dataset is None: dataset = IOIDataset(tokenizer, num_samples=num_samples, symmetric=symmetric) - def collate(samples): + def collate(samples: list[dict]) -> dict: #type: ignore prompts = [sample["prompt"] for sample in samples] - padded_prompts = torch.nn.utils.rnn.pad_sequence(prompts, batch_first=True) + padded_prompts = t.nn.utils.rnn.pad_sequence(prompts, batch_first=True) return { "prompt": padded_prompts, "IO": [sample["IO"] for sample in samples], @@ -344,7 +370,7 @@ def collate(samples): s = s[:min_len] # Remove identical prefixes - start_idx = torch.where(io != s)[0][0] + start_idx = t.where(io != s)[0][0] io = io[start_idx] s = s[start_idx] logit_idx = prefix_length + start_idx - 1 @@ -361,32 +387,27 @@ def collate(samples): total_logit_diff += logit_diff.item() return { - "Logit Difference": total_logit_diff / len(dataset), - "Accuracy": total_correct / len(dataset), + "Logit Difference": total_logit_diff / dataset_len(dataset), + "Accuracy": total_correct / dataset_len(dataset), } -from iit.utils.config import DEVICE class IOIDatasetWrapper(IOIDataset): - def __init__(self, *args, - device=DEVICE, **kwargs): - super().__init__(*args, **kwargs) - self.device = device - def get_inputs(self): + def get_inputs(self) -> Tensor: items = [self.__getitem__(i) for i in range(len(self))] inputs = [item[0] for item in items] - inputs_tensor = torch.stack(inputs) + inputs_tensor = t.stack(inputs) return inputs_tensor - def get_targets(self): + def get_targets(self) -> list[Tensor]: items = [self.__getitem__(i) for i in range(len(self))] targets = [item[1] for item in items] return targets - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: # type: ignore x = super().__getitem__(idx) prompt = x['prompt'] - y = list(prompt[1:]) - y = torch.nn.functional.one_hot(torch.tensor(y), num_classes=self.tokenizer.vocab_size).float() + y_list = list(prompt[1:]) + y = 
t.nn.functional.one_hot(t.tensor(y_list), num_classes=self.tokenizer.vocab_size).float() return (x['prompt'][:-1].to(self.device), (y).to(self.device), (x['IO']).to(self.device)) \ No newline at end of file diff --git a/iit/tasks/ioi/ioi_hl.py b/iit/tasks/ioi/ioi_hl.py index 268c79f..7f033ae 100644 --- a/iit/tasks/ioi/ioi_hl.py +++ b/iit/tasks/ioi/ioi_hl.py @@ -1,11 +1,14 @@ +from typing import Callable + import torch as t +from torch import Tensor from transformer_lens.hook_points import HookedRootModule, HookPoint from iit.tasks.hl_model import HLModel class DuplicateHead(t.nn.Module): - def forward(self, tokens:t.Tensor): + def forward(self, tokens : Tensor) -> Tensor: # Write the last previous position of any duplicated token (used at S2) positions = (tokens[..., None, :] == tokens[..., :, None]) # batch seq1 seq2 positions = t.triu(positions, diagonal=1) # only consider positions before this one @@ -15,7 +18,7 @@ def forward(self, tokens:t.Tensor): return ret class PreviousHead(t.nn.Module): - def forward(self, tokens:t.Tensor): + def forward(self, tokens: Tensor) -> Tensor: # copy token S1 to token S1+1 (used at S1+1) ret = t.full_like(tokens, -1) ret[..., 1:] = tokens[..., :-1] @@ -26,7 +29,7 @@ class InductionHead(t.nn.Module): class SInhibitionHead(t.nn.Module): - def forward(self, tokens: t.Tensor, duplicate: t.Tensor): + def forward(self, tokens: Tensor, duplicate: Tensor) -> Tensor: """ when duplicate is not -1, output a flag to the name mover head to NOT copy this name @@ -40,8 +43,8 @@ def forward(self, tokens: t.Tensor, duplicate: t.Tensor): # extract token positions we care about from duplicate duplicate_pos_at_duplicates = t.where(duplicate != -1) duplicate_pos_at_tokens = duplicate[duplicate_pos_at_duplicates[0], duplicate_pos_at_duplicates[1]] - duplicate_pos_at_tokens = (duplicate_pos_at_duplicates[0], duplicate_pos_at_tokens) - duplicate_tokens = tokens[duplicate_pos_at_tokens] + duplicate_pos_at_tokens_tup = (duplicate_pos_at_duplicates[0], duplicate_pos_at_tokens) + duplicate_tokens = tokens[duplicate_pos_at_tokens_tup] assert ret[duplicate_pos_at_duplicates].abs().sum() == 0 # sanity check, to make sure we're not overwriting anything # replace ret with the duplicated tokens ret[duplicate_pos_at_duplicates] = duplicate_tokens @@ -49,12 +52,12 @@ def forward(self, tokens: t.Tensor, duplicate: t.Tensor): return ret class NameMoverHead(t.nn.Module): - def __init__(self, names, d_vocab:int=40, ): + def __init__(self, names: Tensor, d_vocab : int=40): super().__init__() self.d_vocab_out = d_vocab self.names = names - def forward(self, tokens: t.Tensor, s_inhibition: t.Tensor): + def forward(self, tokens: Tensor, s_inhibition: Tensor) -> Tensor: """ increase logit of all names in the sentence, except those flagged by s_inhibition """ @@ -84,7 +87,7 @@ class IOI_HL(HookedRootModule, HLModel): - S-inhibition heads: Inhibit attention of Name Mover Heads to S1 and S2 tokens - Name mover heads: Copy all previous names in the sentence """ - def __init__(self, d_vocab, names): + def __init__(self, d_vocab: int, names: Tensor): super().__init__() self.all_nodes_hook = HookPoint() self.duplicate_head = DuplicateHead() @@ -99,12 +102,12 @@ def __init__(self, d_vocab, names): self.d_vocab = d_vocab self.setup() - def is_categorical(self): + def is_categorical(self) -> bool: return True - def forward(self, args, verbose=False): - show = print if verbose else lambda *args, **kwargs: None - if isinstance(args, t.Tensor): + def forward(self, args: Tensor | tuple, verbose: bool = False) 
-> Tensor: + show: Callable[[t.Any], None] = lambda *args, **kwargs: print(*args, **kwargs) if verbose else None + if isinstance(args, Tensor): input = args elif isinstance(args, tuple): input = args[0] diff --git a/iit/tasks/ioi/test_ioi.py b/iit/tasks/ioi/test_ioi.py index b342b16..4088a95 100644 --- a/iit/tasks/ioi/test_ioi.py +++ b/iit/tasks/ioi/test_ioi.py @@ -1,17 +1,20 @@ +from typing import Callable import torch as t +from torch import Tensor from .ioi_hl import DuplicateHead, PreviousHead, SInhibitionHead, NameMoverHead, IOI_HL from transformer_lens.hook_points import HookPoint +from transformer_lens import ActivationCache import numpy as np IOI_TEST_NAMES = t.tensor([10, 20, 30]) -def nonzero_values(a: t.Tensor): +def nonzero_values(a: Tensor) -> Tensor: return t.cat((a.nonzero(), a[a != 0][:, None]), dim=-1) -def make_hook(corrupted_cache, hook_name): - def hook_fn(hook_point_out: t.Tensor, hook: HookPoint) -> t.Tensor: +def make_hook(corrupted_cache: ActivationCache, hook_name: str) -> Callable[[Tensor, HookPoint], Tensor]: + def hook_fn(hook_point_out: Tensor, hook: HookPoint) -> Tensor: out = hook_point_out.clone() out = corrupted_cache[hook_name] return out @@ -19,17 +22,17 @@ def hook_fn(hook_point_out: t.Tensor, hook: HookPoint) -> t.Tensor: return hook_fn -def test_duplicate_head(): +def test_duplicate_head() -> None: a = DuplicateHead()(t.tensor([[3, 1, 4, 1, 5, 9, 2, 6, 5]])) assert a.equal(t.tensor([[-1, -1, -1, 1, -1, -1, -1, -1, 4]])) -def test_previous_head(): +def test_previous_head() -> None: a = PreviousHead()(t.tensor([[3, 1, 4, 1, 5, 9, 2, 6, 5]])) assert a.equal(t.tensor([[-1, 3, 1, 4, 1, 5, 9, 2, 6]])) -def test_s_inhibition_head(): +def test_s_inhibition_head() -> None: a = SInhibitionHead()( t.tensor([[3, 1, 4, 1, 5, 9, 2, 6, 5]]), t.tensor([[-1, -1, -1, 1, -1, -1, -1, -1, 4]]), @@ -37,7 +40,7 @@ def test_s_inhibition_head(): assert a.equal(t.tensor([[-1, -1, -1, 1, -1, -1, -1, -1, 5]])) -def test_name_mover_head(): +def test_name_mover_head() -> None: a = NameMoverHead(IOI_TEST_NAMES, d_vocab=21)( t.tensor([[1, 2, 10, 20]]), t.tensor([[-1, 20, 10, -1]]) ) @@ -55,7 +58,7 @@ def test_name_mover_head(): ) -def test_ioi_hl(): +def test_ioi_hl() -> None: a = IOI_HL(d_vocab=21, names=IOI_TEST_NAMES)( (t.tensor([[3, 10, 4, 10, 5, 9, 2, 6, 5]]), None, None) ) @@ -76,7 +79,7 @@ def test_ioi_hl(): ) -def test_duplicate_head_patching(): +def test_duplicate_head_patching() -> None: test_names = t.tensor(range(10, 60, 1)) hl_model = IOI_HL(d_vocab=61, names=test_names) @@ -133,7 +136,7 @@ def test_duplicate_head_patching(): -def test_all_nodes_patching(): +def test_all_nodes_patching() -> None: hl_model = IOI_HL(d_vocab=21, names=IOI_TEST_NAMES) p_clean = t.tensor( [[1, 2, IOI_TEST_NAMES[0], 3, 4, IOI_TEST_NAMES[1], 5, 6, IOI_TEST_NAMES[0]]] @@ -154,7 +157,7 @@ def test_all_nodes_patching(): -def test_s_inhibition_head_patching(): +def test_s_inhibition_head_patching() -> None: return # Not implemented yet test_names = t.tensor(range(10, 60, 1)) @@ -204,4 +207,4 @@ def test_s_inhibition_head_patching(): print("Different for ", prompt_type[j], " -> ", prompt_type[i]) else: print("Same for ", prompt_type[j], " -> ", prompt_type[i]) - \ No newline at end of file + diff --git a/iit/tasks/ioi/utils.py b/iit/tasks/ioi/utils.py index a88a17b..e80c9a4 100644 --- a/iit/tasks/ioi/utils.py +++ b/iit/tasks/ioi/utils.py @@ -1,7 +1,17 @@ -from .ioi_dataset_tl import * -from .ioi_hl import * +import torch as t -def make_ioi_dataset_and_hl(num_samples, ll_model, NAMES, device=DEVICE, 
verbose=False): +from iit.model_pairs.ll_model import LLModel +from iit.utils.config import DEVICE +from .ioi_hl import IOI_HL +from .ioi_dataset_tl import IOIDataset, IOIDatasetWrapper + +def make_ioi_dataset_and_hl( + num_samples: int, + ll_model: LLModel, + NAMES: list[str], + device: t.device = DEVICE, + verbose: bool = False + ) -> tuple[IOIDatasetWrapper, IOI_HL]: ioi_dataset_tl = IOIDataset( num_samples=num_samples, tokenizer=ll_model.tokenizer, diff --git a/iit/tasks/mnist_pvr/dataset.py b/iit/tasks/mnist_pvr/dataset.py index 98c8fba..539e54e 100644 --- a/iit/tasks/mnist_pvr/dataset.py +++ b/iit/tasks/mnist_pvr/dataset.py @@ -1,10 +1,14 @@ +from typing import Optional + import torch as t +from torch import Tensor import torchvision +import torchvision.datasets as datasets from torch.utils.data import Dataset from PIL import Image, ImageOps import numpy as np from .utils import * -from iit.utils.index import Ix, Index +from iit.utils.index import Ix, TorchIndex from iit.utils.nodes import HLNode @@ -17,21 +21,21 @@ class ImagePVRDataset(Dataset): def __init__( self, - base_dataset, + base_dataset: datasets.MNIST, class_map: dict[int, int] = MNIST_CLASS_MAP, - seed=0, - use_cache=True, - length=200000, - iid=True, - pad_size=0, - unique_per_quad=False, + seed: int = 0, + use_cache: bool = True, + length: int = 200000, + iid: bool = True, + pad_size: int = 0, + unique_per_quad: bool = False, ): self.base_dataset = base_dataset self.class_map = class_map self.seed = seed self.rng = np.random.default_rng(seed) assert all(v in {1, 2, 3} for v in class_map.values()) - self.cache = {} + self.cache: dict[int, tuple[Tensor, Tensor, Tensor]] = {} self.use_cache = False self.length = length self.iid = iid @@ -43,17 +47,17 @@ def __init__( assert ( len(self.base_dataset) >= 4 * self.length ), "Dataset is too small for non-iid mode" - self.input_shape = None + self.input_shape: Optional[t.Size] = None self.set_input_shape(self[0][0].unsqueeze(0).shape) - def set_input_shape(self, shape): + def set_input_shape(self, shape: t.Size) -> None: self.input_shape = shape - def get_input_shape(self): + def get_input_shape(self) -> None | t.Size: return self.input_shape @staticmethod - def concatenate_2x2(images): + def concatenate_2x2(images: list[Image.Image]) -> Image.Image: """ Concatenates four PIL.Image.Image objects into a 2x2 square. """ @@ -68,15 +72,15 @@ def concatenate_2x2(images): return new_image - def make_label_from_intermediate(self, intermediate_vars): + def make_label_from_intermediate(self, intermediate_vars: Tensor) -> Tensor: """ Returns the label for the new image based on the intermediate variables. 
""" - pointer = self.class_map[intermediate_vars[0].item()] + pointer = self.class_map[int(intermediate_vars[0].item())] new_label = t.tensor(intermediate_vars[pointer].item()) return new_label - def __getitem__(self, index): + def __getitem__(self, index: int) -> tuple[Tensor, Tensor, Tensor]: if index in self.cache and self.use_cache: return self.cache[index] if self.iid: @@ -104,7 +108,7 @@ def __getitem__(self, index): ] # print(f"Padding images by {self.pad_size}") new_image = self.concatenate_2x2(images) - new_image = torchvision.transforms.functional.to_tensor(new_image) + new_image_output = torchvision.transforms.functional.to_tensor(new_image) base_label = base_items[0][1] pointer = self.class_map[base_label] @@ -114,21 +118,21 @@ def __getitem__(self, index): assert ( new_label == new_label_from_func ), f"new_label: {new_label}; new_label_from_func: {new_label_from_func}" - ret = new_image, new_label, intermediate_vars + ret = new_image_output, new_label, intermediate_vars if self.use_cache: self.cache[index] = ret return ret - def __len__(self): + def __len__(self) -> int: return self.length def patch_at_hl_idx( self, - input: t.Tensor, - intermediate_var: t.Tensor, - idx: Index, + input: Tensor, + intermediate_var: Tensor, + idx: TorchIndex, idx_to_intermediate: int, - ): + ) -> tuple[Tensor, Tensor, Tensor]: """ Patches the input and label to be compatible with the PVR model. """ @@ -158,7 +162,7 @@ def patch_at_hl_idx( def patch_batch_at_hl( self, batch: list, intermediate_vars: list, hl_node: HLNode - ): + ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]: """ Patches the input and label to be compatible with the PVR model. """ @@ -175,9 +179,12 @@ def patch_batch_at_hl( new_labels.append(new_label) return new_batch, new_labels, new_intermediate_vars - def get_idx_and_intermediate(self, hl_node: HLNode): + def get_idx_and_intermediate(self, hl_node: HLNode) -> tuple[TorchIndex, int]: input_shape = self.get_input_shape() - width, height = input_shape[2], input_shape[3] + if isinstance(input_shape, t.Size): + width, height = input_shape[2], input_shape[3] + else: + raise ValueError("Cannot obtain input shape from base dataset.") if "hook_tl" in hl_node.name: idx = Ix[None, : width // 2, : height // 2] idx_to_intermediate = 0 diff --git a/iit/tasks/mnist_pvr/get_alignment.py b/iit/tasks/mnist_pvr/get_alignment.py index d653dd9..f2aab81 100644 --- a/iit/tasks/mnist_pvr/get_alignment.py +++ b/iit/tasks/mnist_pvr/get_alignment.py @@ -1,17 +1,20 @@ +from transformer_lens.hook_points import HookedRootModule + from .pvr_hl import get_corr, MNIST_PVR_HL from .pvr_check_leaky_hl import get_corr as get_corr_leaky, MNIST_PVR_Leaky_HL import torchvision from iit.utils.config import DEVICE import torch as t from iit.utils.wrapper import HookedModuleWrapper +from iit.utils.correspondence import Correspondence -def get_alignment(config, task): +def get_alignment(config: dict, task: str) -> tuple[HookedModuleWrapper, HookedRootModule, Correspondence]: if config["model"] == "resnet18": resnet18 = torchvision.models.resnet18().to(DEVICE) # 11M parameters resnet18.fc = t.nn.Linear(512, 10).to(DEVICE) ll_model = HookedModuleWrapper( - resnet18, name="resnet18", recursive=True, hook_self=False + resnet18, name="resnet18", recursive=True, get_hook_self=False ).to(DEVICE) else: raise ValueError(f"Unknown model {config['model']}") diff --git a/iit/tasks/mnist_pvr/pvr_check_leaky_hl.py b/iit/tasks/mnist_pvr/pvr_check_leaky_hl.py index 941daf6..97921aa 100644 --- 
a/iit/tasks/mnist_pvr/pvr_check_leaky_hl.py +++ b/iit/tasks/mnist_pvr/pvr_check_leaky_hl.py @@ -1,13 +1,16 @@ +from typing import Callable import torch as t +from torch import Tensor from transformer_lens.hook_points import HookedRootModule, HookPoint from iit.utils.config import DEVICE -from .utils import * +from .utils import MNIST_CLASS_MAP from iit.utils.index import Ix -from iit.utils.nodes import HLNode, LLNode +from iit.utils.nodes import HLNode, LLNode, HookName +from iit.utils.correspondence import Correspondence class MNIST_PVR_Leaky_HL(HookedRootModule): - def __init__(self, class_map=MNIST_CLASS_MAP, device=DEVICE): + def __init__(self, class_map: dict = MNIST_CLASS_MAP, device: t.device = DEVICE): super().__init__() hook_str = """hook_{}_leaked_to_{}""" self.leaky_hooks = {} @@ -19,7 +22,7 @@ def __init__(self, class_map=MNIST_CLASS_MAP, device=DEVICE): for i in ["tl", "tr", "bl", "br"]: for j in ["tl", "tr", "bl", "br"]: if i != j: - hl_node = HLNode(hook_str.format(i, j), 10, None) + hl_node = HLNode(hook_str.format(i, j), 10, Ix[[None]]) self.leaky_hooks[hl_node] = HookPoint() setattr( self, hl_node.name, self.leaky_hooks[hl_node] @@ -29,7 +32,7 @@ def __init__(self, class_map=MNIST_CLASS_MAP, device=DEVICE): ) self.setup() - def get_idx_to_intermediate(self, name: str): + def get_idx_to_intermediate(self, name: str) -> Callable[[Tensor], Tensor]: if "hook_tl" in name: return lambda intermediate_vars: intermediate_vars[:, 0] if "hook_tr" in name: @@ -41,8 +44,8 @@ def get_idx_to_intermediate(self, name: str): else: raise ValueError(f"Hook name {name} not recognised") - def forward(self, args): - input, label, intermediate_data = args + def forward(self, args: tuple[t.Any, t.Any, Tensor]) -> Tensor: + _, _, intermediate_data = args # print([a.shape for a in args]) tl, tr, bl, br = [intermediate_data[:, i] for i in range(4)] # print(f"intermediate_data is a {type(intermediate_data)}; tl is a {type(tl)}") @@ -68,7 +71,7 @@ def forward(self, args): hl = MNIST_PVR_Leaky_HL().to(DEVICE) -def get_corr(mode, hook_point, model, input_shape): +def get_corr(mode: str, hook_point: str, model: HookedRootModule, input_shape: tuple[int, int, int, int]) -> Correspondence: with t.no_grad(): out, cache = model.run_with_cache(t.zeros(input_shape, device=DEVICE)) input_shape = cache[hook_point].shape @@ -104,5 +107,5 @@ def get_corr(mode, hook_point, model, input_shape): corr[k] = {LLNode(name=hook_point, index=br_idx)} else: print(f"!!!!!! 
Skipping {k}") - return corr + return Correspondence(corr) raise NotImplementedError(mode) diff --git a/iit/tasks/mnist_pvr/pvr_hl.py b/iit/tasks/mnist_pvr/pvr_hl.py index 965e3ae..a98fb5c 100644 --- a/iit/tasks/mnist_pvr/pvr_hl.py +++ b/iit/tasks/mnist_pvr/pvr_hl.py @@ -1,18 +1,20 @@ -from iit.utils.nodes import HookName +from typing import Callable import torch as t +from torch import Tensor from transformer_lens.hook_points import HookedRootModule, HookPoint from iit.utils.config import DEVICE from iit.utils.nodes import HLNode, LLNode from iit.utils.index import Ix -from .utils import * +from .utils import MNIST_CLASS_MAP from iit.tasks.hl_model import HLModel +from iit.utils.correspondence import Correspondence class MNIST_PVR_HL(HookedRootModule, HLModel): """ A high-level implementation of the algorithm used for MNIST_PVR """ - def __init__(self, class_map=MNIST_CLASS_MAP, device=DEVICE): + def __init__(self, class_map: dict = MNIST_CLASS_MAP, device: t.device = DEVICE): super().__init__() self.hook_tl = HookPoint() self.hook_tr = HookPoint() @@ -23,13 +25,13 @@ def __init__(self, class_map=MNIST_CLASS_MAP, device=DEVICE): ) self.setup() - def uses_intermediate_variables(self): + def uses_intermediate_variables(self) -> bool: return True - def is_categorical(self): + def is_categorical(self) -> bool: return True - def get_idx_to_intermediate(self, name: HookName): + def get_idx_to_intermediate(self, name: str) -> Callable[[Tensor], Tensor]: """ Returns a function that takes in a list of intermediate variables and returns the index of the one to use. """ @@ -44,8 +46,8 @@ def get_idx_to_intermediate(self, name: HookName): else: raise NotImplementedError(name) - def forward(self, args): - input, label, intermediate_data = args + def forward(self, args: tuple[t.Any, t.Any, Tensor]) -> Tensor: + _, _, intermediate_data = args # print([a.shape for a in args]) tl, tr, bl, br = [intermediate_data[:, i] for i in range(4)] # print(f"intermediate_data is a {type(intermediate_data)}; tl is a {type(tl)}") @@ -61,14 +63,14 @@ def forward(self, args): # %% hl_nodes = { - "hook_tl": HLNode("hook_tl", 10, None), - "hook_tr": HLNode("hook_tr", 10, None), - "hook_bl": HLNode("hook_bl", 10, None), - "hook_br": HLNode("hook_br", 10, None), + "hook_tl": HLNode("hook_tl", 10, Ix[[None]]), + "hook_tr": HLNode("hook_tr", 10, Ix[[None]]), + "hook_bl": HLNode("hook_bl", 10, Ix[[None]]), + "hook_br": HLNode("hook_br", 10, Ix[[None]]), } -def get_corr(mode, hook_point, model: HookedRootModule, input_shape): +def get_corr(mode: str, hook_point: str, model: HookedRootModule, input_shape: tuple[int, int, int, int]) -> Correspondence: with t.no_grad(): out, cache = model.run_with_cache(t.zeros(input_shape).to(DEVICE)) # print(out.shape) @@ -134,4 +136,4 @@ def get_corr(mode, hook_point, model: HookedRootModule, input_shape): ) }, } - return corr + return Correspondence(corr) diff --git a/iit/tasks/mnist_pvr/utils.py b/iit/tasks/mnist_pvr/utils.py index 2917273..28c43f8 100644 --- a/iit/tasks/mnist_pvr/utils.py +++ b/iit/tasks/mnist_pvr/utils.py @@ -1,16 +1,17 @@ import torchvision.datasets as datasets import torch as t +from torch import Tensor import torchvision MNIST_CLASS_MAP = {k: [1, 1, 1, 1, 2, 2, 2, 3, 3, 3][k] for k in range(10)} mnist_size = 28 -def make_mnist_dataset(): +def make_mnist_dataset() -> tuple[datasets.MNIST, datasets.MNIST]: mnist_train = datasets.MNIST("./data", download=True) mnist_test = datasets.MNIST("./data", train=False, download=True) return mnist_train, mnist_test, -def 
visualize_datapoint(dataset, index): +def visualize_datapoint(dataset: datasets.MNIST, index: int) -> None: image, label, intermediate_vars = dataset[index] print(f"Label: {label}") print(f"Intermediate vars: {intermediate_vars}") @@ -19,6 +20,6 @@ def visualize_datapoint(dataset, index): image.show() -def visualize_image(input): +def visualize_image(input: Tensor) -> None: image = torchvision.transforms.functional.to_pil_image(input) image.show() diff --git a/iit/tasks/task_loader.py b/iit/tasks/task_loader.py index 560220d..881c57b 100644 --- a/iit/tasks/task_loader.py +++ b/iit/tasks/task_loader.py @@ -1,8 +1,12 @@ +from transformer_lens.hook_points import HookedRootModule # type: ignore + from .mnist_pvr.dataset import ImagePVRDataset from .mnist_pvr.utils import make_mnist_dataset from .mnist_pvr.get_alignment import get_alignment as get_mnist_pvr_corr -from transformer_lens.hook_points import HookedRootModule from iit.utils.iit_dataset import IITDataset +from iit.utils.correspondence import Correspondence +from iit.utils.wrapper import HookedModuleWrapper +from iit.tasks.hl_model import HLModel def get_dataset( @@ -39,7 +43,7 @@ def get_dataset( return IITDataset(train_set, train_set), IITDataset(test_set, test_set) -def get_alignment(task: str, config: dict = {}): +def get_alignment(task: str, config: dict = {}) -> tuple[HookedModuleWrapper | None, HookedRootModule | None, Correspondence]: if "pvr" in task: default_config = { "mode": "q", @@ -51,13 +55,8 @@ def get_alignment(task: str, config: dict = {}): return get_mnist_pvr_corr(default_config, task) if "ioi" in task: from .ioi import corr - return corr + return None, None, corr raise ValueError(f"Unknown task {task}") -def get_default_corr(task: str) -> dict: - if "pvr" in task: - return get_alignment(task)[-1] - elif "ioi" in task: - from .ioi import corr_dict - return corr_dict - raise ValueError(f"Unknown task {task}") +def get_default_corr(task: str) -> Correspondence: + return get_alignment(task)[-1] diff --git a/iit/utils/argparsing.py b/iit/utils/argparsing.py new file mode 100644 index 0000000..3ac44c3 --- /dev/null +++ b/iit/utils/argparsing.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +import torch as t + +@dataclass +class IOIArgParseNamespace: + # global + output_dir: str = "./results" + include_mlp: bool = False + use_wandb: bool = False + num_samples: int = 18000 + device: str = "cuda" if t.cuda.is_available() else "cpu" + batch_size: int = 512 + next_token: bool = False + + # eval + weights: str = "100_100_40" + mean: bool = True + load_from_wandb: bool = False + + # train + epochs: int = 1000 + lr: float = 1e-3 + iit: float = 1.0 # iit loss weight + b: float = 1.0 # baseline loss weight + s: float = 0.4 # siit loss weight + clip_grad_norm: float = 1.0 + use_single_loss: bool = False + save_to_wandb: bool = False + \ No newline at end of file diff --git a/iit/utils/config.py b/iit/utils/config.py index 96ac934..e643053 100644 --- a/iit/utils/config.py +++ b/iit/utils/config.py @@ -1,4 +1,4 @@ import torch DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -WANDB_ENTITY = "cybershiptrooper" \ No newline at end of file +WANDB_ENTITY = "cybershiptrooper" #TODO: This should be editable by the user at runtime \ No newline at end of file diff --git a/iit/utils/correspondence.py b/iit/utils/correspondence.py index 5767457..3894210 100644 --- a/iit/utils/correspondence.py +++ b/iit/utils/correspondence.py @@ -1,22 +1,27 @@ +from typing import Optional, Any import pickle from iit.utils.nodes 
import HLNode, LLNode +from iit.utils.index import Ix, TorchIndex class Correspondence(dict[HLNode, set[LLNode]]): - def __init__( + def __init__( # type: ignore self, *args, - suffixes={"attn": "attn.hook_result", "mlp": "mlp.hook_post"}, + suffixes: dict = {"attn": "attn.hook_result", "mlp": "mlp.hook_post"}, **kwargs, ): super().__init__(*args, **kwargs) self.suffixes = suffixes - def __setattr__(self, key, value): - if key == "suffixes": + def __setattr__(self, key: HLNode | str, value: set[LLNode] | dict[str, str]) -> None: # type: ignore + if isinstance(key, str): + if key != "suffixes": + raise ValueError(f"Key must be an HLNode or 'suffixes', got {key}") assert isinstance(value, dict), ValueError( f"__value is not a dict, but {type(value)}" ) + assert isinstance(value, dict), ValueError(f"suffixes value is not a dict, but {type(value)}") else: assert isinstance(key, HLNode), "key must be of type HLNode, got %s" % type( key @@ -28,48 +33,59 @@ def __setattr__(self, key, value): "__value contains non-LLNode elements" ) # print(self.keys(), self.values()) - super().__setattr__(key, value) + super().__setattr__(key, value) # type: ignore - def get_suffixes(self): + def get_suffixes(self) -> dict: return self.suffixes @staticmethod - def get_hook_suffix(corr: dict[HLNode, set[LLNode]]) -> dict[str, str]: - suffixes = {} - for hl_node, ll_nodes in corr.items(): + def get_hook_suffix(corr: dict[HLNode, set[LLNode]] | dict[str, list[tuple[str, TorchIndex, Any]]]) -> dict[str, str]: + suffixes: dict[str, str] = {} + for _, ll_nodes in corr.items(): for ll_node in ll_nodes: # add everything after 'blocks..' to the set - suffix = ll_node.name.split(".")[2:] - suffix = ".".join(suffix) - if "attn" in ll_node.name: + if isinstance(ll_node, LLNode): + ll_node_name = ll_node.name + else: + ll_node_name = ll_node[0] + suffix_pieces = ll_node_name.split(".")[2:] + suffix = ".".join(suffix_pieces) + if "attn" in ll_node_name: if "attn" in suffixes and suffixes["attn"] != suffix: raise ValueError( f"Multiple attn suffixes found: {suffixes['attn']} and {suffix}, multiple attn hook locations are not supported yet." ) suffixes["attn"] = suffix - elif "mlp" in ll_node.name: + elif "mlp" in ll_node_name: if "mlp" in suffixes and suffixes["mlp"] != suffix: raise ValueError( f"Multiple mlp suffixes found: {suffixes['mlp']} and {suffix}, multiple mlp hook locations are not supported yet." 
) suffixes["mlp"] = suffix else: - raise ValueError(f"Unknown node type {ll_node.name}") + raise ValueError(f"Unknown node type {ll_node_name}") return suffixes @classmethod - def make_corr_from_dict(cls, d: dict, suffixes=None, make_suffixes_from_corr=False): + def make_corr_from_dict( + cls, + d: dict[str, list[tuple[str, TorchIndex, Any]]], + suffixes: Optional[dict[str, str]] = None, + make_suffixes_from_corr: bool = False + ) -> "Correspondence": if make_suffixes_from_corr: suffixes = Correspondence.get_hook_suffix(d) - return cls( - { - HLNode(k, -1): {LLNode(name=node_name, index=None) for node_name in v} - for k, v in d.items() - }, - suffixes=suffixes, - ) + + input_dict = { + HLNode(k, -1): {LLNode(name=node_name, index=index) for node_name, index, _ in v} + for k, v in d.items() + } + if suffixes is not None: + return cls(input_dict, suffixes=suffixes) + else: + return cls(input_dict) - def save(self, filename: str): + def save(self, filename: str) -> None: pickle.dump(self, open(filename, "wb")) diff --git a/iit/utils/eval_ablations.py b/iit/utils/eval_ablations.py index d353647..95ba969 100644 --- a/iit/utils/eval_ablations.py +++ b/iit/utils/eval_ablations.py @@ -1,3 +1,4 @@ +from typing import Callable, Optional import os from enum import Enum from typing import Dict, List, Literal @@ -5,14 +6,14 @@ import dataframe_image as dfi import pandas as pd import torch as t -from torch.utils.data import Dataset +from torch import Tensor from tqdm import tqdm from transformer_lens import HookedTransformer from transformer_lens.HookedTransformer import HookPoint import iit.utils.index as index from iit.model_pairs.base_model_pair import BaseModelPair -from iit.model_pairs.iit_model_pair import IITModelPair +from iit.utils.eval_datasets import IITUniqueDataset from iit.utils.nodes import LLNode from iit.utils.eval_metrics import kl_div from iit.utils.iit_dataset import IITDataset @@ -32,18 +33,18 @@ class Categorical_Metric(Enum): def do_intervention( model_pair: BaseModelPair, - base_input: t.Tensor, - ablation_input: t.Tensor, + base_input: Tensor, + ablation_input: Tensor, node: LLNode, - hook_fn: callable, -): + hook_fn: Callable[[Tensor, HookPoint], Tensor], +) -> Tensor: """ Runs the model with ablation_input and caches the result. Then runs the model with base_input and with hook_fn applied to the node. 
Args: model_pair: BaseModelPair - base_input: t.Tensor - ablation_input: t.Tensor + base_input: Tensor + ablation_input: Tensor node: LLNode hook_fn: callable """ @@ -55,16 +56,17 @@ def do_intervention( return out +# TODO: change name to reflect that it's not just for resampling def resample_ablate_node( - model_pair: IITModelPair, - base_in: tuple[t.Tensor, t.Tensor, t.Tensor], - ablation_in: tuple[t.Tensor, t.Tensor, t.Tensor], + model_pair: BaseModelPair, + base_in: tuple[Tensor, Tensor, Tensor], + ablation_in: tuple[Tensor, Tensor, Tensor], node: LLNode, - hook_fn: callable, - atol=5e-2, - verbose=False, + hook_fn: Callable[[Tensor, HookPoint], Tensor], + atol: float = 5e-2, + verbose: bool = False, categorical_metric: Categorical_Metric = Categorical_Metric.ACCURACY, -): # TODO: change name to reflect that it's not just for resampling +) -> float: base_x, base_y = base_in[0:2] ablation_x, ablation_y = ablation_in[0:2] ll_out = do_intervention(model_pair, base_x, ablation_x, node, hook_fn) @@ -181,9 +183,9 @@ def check_causal_effect( batch_size: int = 256, node_type: Literal["a", "c", "n", "individual_c"] = "a", categorical_metric: Categorical_Metric = Categorical_Metric.ACCURACY, - hook_maker: callable = None, + hook_maker: Optional[Callable] = None, verbose: bool = False, -): +) -> dict[LLNode, float]: assert node_type in [ "a", "c", @@ -213,7 +215,7 @@ def check_causal_effect( hook_fns[node] = hook_maker(node) else: hook_fns[node] = model_pair.make_ll_ablation_hook(node) - results[node] = 0 + results[node] = 0. loader = dataset.make_loader(batch_size=batch_size, num_workers=0) for base_in, ablation_in in tqdm(loader): @@ -231,7 +233,11 @@ def check_causal_effect( return results -def get_mean_cache(model, dataset: Dataset, batch_size=8): +def get_mean_cache( + model: BaseModelPair | HookedTransformer, + dataset: IITDataset, + batch_size: int = 8 + ) -> dict[str, Tensor]: loader = dataset.make_loader(batch_size=batch_size, num_workers=0) mean_cache = {} for batch in tqdm(loader): @@ -253,18 +259,21 @@ def get_mean_cache(model, dataset: Dataset, batch_size=8): def make_ablation_hook( node: LLNode, - mean_cache: dict[str, t.Tensor], + mean_cache: Optional[dict[str, Tensor]] = None, use_mean_cache: bool = True, -) -> callable: +) -> Callable[[Tensor, HookPoint], Tensor]: if node.subspace is not None: raise NotImplementedError("Subspace not supported yet.") - def zero_hook(hook_point_out: t.Tensor, hook: HookPoint) -> t.Tensor: + def zero_hook(hook_point_out: Tensor, hook: HookPoint) -> Tensor: hook_point_out[node.index.as_index] = 0 return hook_point_out - def mean_hook(hook_point_out: t.Tensor, hook: HookPoint) -> t.Tensor: - cached_tensor = mean_cache[node.name] + def mean_hook(hook_point_out: Tensor, hook: HookPoint) -> Tensor: + if isinstance(mean_cache, dict): + cached_tensor = mean_cache[node.name] + else: + raise ValueError("mean_cache must be a dict when use_mean_cache is True") hook_point_out[node.index.as_index] = cached_tensor[node.index.as_index] return hook_point_out @@ -274,13 +283,12 @@ def mean_hook(hook_point_out: t.Tensor, hook: HookPoint) -> t.Tensor: def ablate_nodes( - model_pair: IITModelPair, - base_input: tuple[t.Tensor, t.Tensor, t.Tensor], - fwd_hooks: List[tuple[str, callable]], - atol=5e-2, - relative_change=True, - verbose=False, -): + model_pair: BaseModelPair, + base_input: tuple[Tensor, Tensor, Tensor], + fwd_hooks: List[tuple[str, Callable]], + atol: float = 5e-2, + relative_change: bool = True +) -> Tensor: """ Returns 1 - accuracy of the model after 
ablating the nodes in fwd_hooks. Args: @@ -333,13 +341,13 @@ def ablate_nodes( def get_causal_effects_for_all_nodes( - model_pair, - uni_test_set, - batch_size=256, - use_mean_cache=True, - categorical_metric=Categorical_Metric.ACCURACY, - individual_nodes=True, -): + model_pair: BaseModelPair, + uni_test_set: IITUniqueDataset, + batch_size: int = 256, + use_mean_cache: bool = True, + categorical_metric: Categorical_Metric = Categorical_Metric.ACCURACY, + individual_nodes: bool = True, +) -> tuple[dict[LLNode, float], dict[LLNode, float]]: mean_cache = None if use_mean_cache: mean_cache = get_mean_cache(model_pair, uni_test_set, batch_size=batch_size) @@ -347,17 +355,13 @@ def get_causal_effects_for_all_nodes( model_pair, uni_test_set, node_type="n", - verbose=False, mean_cache=mean_cache, - categorical_metric=categorical_metric, ) za_result_in_circuit = check_causal_effect_on_ablation( model_pair, uni_test_set, node_type="c" if not individual_nodes else "individual_c", - verbose=False, mean_cache=mean_cache, - categorical_metric=categorical_metric, ) return za_result_not_in_circuit, za_result_in_circuit @@ -367,10 +371,8 @@ def check_causal_effect_on_ablation( dataset: IITDataset, batch_size: int = 256, node_type: str = "a", - mean_cache: dict[str, t.Tensor] = None, - categorical_metric: Categorical_Metric = Categorical_Metric.ACCURACY, - verbose: bool = False, -): + mean_cache: Optional[dict[str, Tensor]] = None, +) -> dict[LLNode, float]: use_mean_cache = True if mean_cache else False assert node_type in [ "a", @@ -398,20 +400,23 @@ def check_causal_effect_on_ablation( for node in all_nodes: hookers[node] = make_ablation_hook(node, mean_cache, use_mean_cache) - results[node] = 0 + results[node] = 0. loader = dataset.make_loader(batch_size=batch_size, num_workers=0) for batch in tqdm(loader): for node, hooker in hookers.items(): - results[node] += ablate_nodes(model_pair, batch, [(node.name, hooker)]) + results[node] += ablate_nodes(model_pair, batch, [(node.name, hooker)]).item() for node, result in results.items(): results[node] = result / len(loader) return results -def make_dataframe_of_results(result_not_in_circuit, result_in_circuit): - def create_name(node): +def make_dataframe_of_results( + result_not_in_circuit: dict[LLNode, float], + result_in_circuit: dict[LLNode, float] + ) -> pd.DataFrame: + def create_name(node: LLNode) -> str: if "mlp" in node.name: return node.name if node.index is not None and node.index != index.Ix[[None]]: @@ -434,12 +439,12 @@ def create_name(node): def make_combined_dataframe_of_results( - result_not_in_circuit, - result_in_circuit, - za_result_not_in_circuit, - za_result_in_circuit, + result_not_in_circuit: dict[LLNode, float], + result_in_circuit: dict[LLNode, float], + za_result_not_in_circuit: dict[LLNode, float], + za_result_in_circuit: dict[LLNode, float], use_mean_cache: bool = False, -): +) -> pd.DataFrame: df = make_dataframe_of_results(result_not_in_circuit, result_in_circuit) df2 = make_dataframe_of_results(za_result_not_in_circuit, za_result_in_circuit) df2_causal_effect = df2.pop("causal effect") @@ -457,12 +462,11 @@ def get_circuit_score( model_pair: BaseModelPair, dataset: IITDataset, nodes_to_ablate: List[LLNode], - mean_cache: Dict[str, t.Tensor] = None, + mean_cache: Optional[Dict[str, Tensor]] = None, batch_size: int = 256, use_mean_cache: bool = False, relative_change: bool = True, - verbose: bool = False, -): +) -> float: """ Returns the accuracy of the model after ablating the nodes in nodes_to_ablate. 
Defaults to zero ablation. @@ -485,15 +489,17 @@ def get_circuit_score( model_pair, base_input, fwd_hooks, - verbose=verbose, relative_change=relative_change, ) return result / len(loader) def save_result( - df: pd.DataFrame, save_dir: str, model_pair: BaseModelPair = None, suffix="" -): + df: pd.DataFrame, + save_dir: str, + model_pair: Optional[BaseModelPair] = None, + suffix: str = "" +) -> None: os.makedirs(save_dir, exist_ok=True) try: dfi.export(df, f"{save_dir}/results{suffix}.png") diff --git a/iit/utils/eval_datasets.py b/iit/utils/eval_datasets.py index 5067220..fd79332 100644 --- a/iit/utils/eval_datasets.py +++ b/iit/utils/eval_datasets.py @@ -1,20 +1,24 @@ -import numpy as np -from torch.utils.data import Dataset, DataLoader +from typing import Any, cast, Sized + + +import torch as t +from torch import Tensor +from torch.utils.data import Dataset + from iit.utils.config import DEVICE -import torch from iit.utils.iit_dataset import IITDataset class IITUniqueDataset(IITDataset): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, base_data: Dataset, ablation_data: Dataset, seed: int = 0, every_combination: bool = False, device: t.device = DEVICE) -> None: + super().__init__(base_data, ablation_data, seed, every_combination, device) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Any: return self.base_data[index] - def __len__(self): - return len(self.base_data) - + def __len__(self) -> int: + return len(cast(Sized, self.base_data)) + @staticmethod - def collate_fn(batch, device=DEVICE): + def collate_fn(batch: tuple, device: t.device = DEVICE) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]: #type: ignore return IITDataset.get_encoded_input_from_torch_input(batch, device) diff --git a/iit/utils/eval_metrics.py b/iit/utils/eval_metrics.py index cc6b013..92e2266 100644 --- a/iit/utils/eval_metrics.py +++ b/iit/utils/eval_metrics.py @@ -1,39 +1,45 @@ -import torch +from typing import Callable + +import torch as t +from torch import Tensor import iit.utils.index as index -def kl_div(a: torch.Tensor, - b: torch.Tensor, - label_idx: index.TorchIndex): +def kl_div( + a: Tensor, + b: Tensor, + label_idx: index.TorchIndex + ) -> Tensor: a_pmf = a[label_idx.as_index] b_pmf = b[label_idx.as_index] # check if b is ints - if b_pmf.dtype in [torch.int32, torch.int64, torch.long, torch.int]: + if b_pmf.dtype in [t.int32, t.int64, t.long, t.int]: if b.shape == a.shape[:-1]: - b_pmf = torch.nn.functional.one_hot(b_pmf, num_classes=a_pmf.shape[-1]).float() + b_pmf = t.nn.functional.one_hot(b_pmf, num_classes=a_pmf.shape[-1]).float() b_pmf = b_pmf.float() - pmf_checker = lambda x: torch.allclose( - x.sum(dim=-1), torch.ones_like(x.sum(dim=-1)) + pmf_checker: Callable[[Tensor], bool] = lambda x: t.allclose( + x.sum(dim=-1), t.ones_like(x.sum(dim=-1)) ) if not pmf_checker(a_pmf): - a_pmf = torch.nn.functional.log_softmax(a_pmf, dim=-1) + a_pmf = t.nn.functional.log_softmax(a_pmf, dim=-1) else: - a_pmf = torch.log(a_pmf) + a_pmf = t.log(a_pmf) if not pmf_checker(b_pmf): - b_pmf = torch.nn.functional.softmax(b_pmf, dim=-1) + b_pmf = t.nn.functional.softmax(b_pmf, dim=-1) - return torch.nn.functional.kl_div( + return t.nn.functional.kl_div( a_pmf, b_pmf, reduction="none", log_target=False ).sum(dim=-1) def accuracy_affected( - a: torch.Tensor, - b: torch.Tensor, - label_unchanged: torch.Tensor, - label_idx: index.TorchIndex): - a_lab = torch.argmax(a[label_idx.as_index], dim=-1) - b_lab = torch.argmax(b[label_idx.as_index], 
dim=-1) + a: Tensor, + b: Tensor, + label_unchanged: Tensor, + label_idx: index.TorchIndex + ) -> Tensor: + a_lab = t.argmax(a[label_idx.as_index], dim=-1) + b_lab = t.argmax(b[label_idx.as_index], dim=-1) - out_unchanged = torch.eq(a_lab, b_lab) + out_unchanged = t.eq(a_lab, b_lab) changed_result = (~out_unchanged).cpu().float() * (~label_unchanged).cpu().float() return changed_result.sum() / (~label_unchanged).sum() diff --git a/iit/utils/eval_scripts.py b/iit/utils/eval_scripts.py index 9e5eb44..fe0e005 100644 --- a/iit/utils/eval_scripts.py +++ b/iit/utils/eval_scripts.py @@ -1,4 +1,10 @@ -import transformer_lens +import os +import json +import argparse + +import torch as t +import numpy as np +from transformer_lens import HookedTransformer from iit.model_pairs.ioi_model_pair import IOI_ModelPair from iit.tasks.ioi import ( @@ -8,15 +14,20 @@ make_corr_dict, suffixes, ) -from iit.utils.eval_ablations import * -import numpy as np +from iit.utils.eval_ablations import ( + check_causal_effect, + get_causal_effects_for_all_nodes, + make_combined_dataframe_of_results, + save_result, +) +from iit.utils.correspondence import Correspondence from iit.utils.iit_dataset import IITDataset from iit.utils.eval_datasets import IITUniqueDataset -import json from iit.utils.io_scripts import load_files_from_wandb +from iit.utils.argparsing import IOIArgParseNamespace -def eval_ioi(args): +def eval_ioi(args: IOIArgParseNamespace) -> None: weights = args.weights use_mean_cache = args.mean device = args.device @@ -27,10 +38,10 @@ def eval_ioi(args): batch_size = args.batch_size num_samples = args.num_samples # load model - ll_cfg = transformer_lens.HookedTransformer.from_pretrained("gpt2").cfg.to_dict() + ll_cfg = HookedTransformer.from_pretrained("gpt2").cfg.to_dict() ll_cfg.update(ioi_cfg) - ll_model = transformer_lens.HookedTransformer(ll_cfg).to(device) + ll_model = HookedTransformer(ll_cfg).to(device) if args.load_from_wandb: load_files_from_wandb( "ioi", @@ -42,7 +53,7 @@ def eval_ioi(args): ) try: ll_model.load_state_dict( - torch.load(f"{save_dir}/ll_model_{weights}.pth", map_location=device) + t.load(f"{save_dir}/ll_model_{weights}.pth", map_location=device) ) except FileNotFoundError: raise FileNotFoundError(f"Model not found at {save_dir}") @@ -97,7 +108,6 @@ def eval_ioi(args): za_result_in_circuit, use_mean_cache=use_mean_cache, ) - suffix = f"_{args.categorical_metric}" save_result(df, results_dir) with open(f"{results_dir}/metric_collection.log", "w") as f: f.write(str(metric_collection)) diff --git a/iit/utils/iit_dataset.py b/iit/utils/iit_dataset.py index 1acc40d..85b5b48 100644 --- a/iit/utils/iit_dataset.py +++ b/iit/utils/iit_dataset.py @@ -1,10 +1,13 @@ # import everything relevant +from typing import Optional, cast, Sized, Callable import numpy as np from torch.utils.data import Dataset from iit.utils.config import DEVICE from torch.utils.data import DataLoader -import torch +import torch as t +from torch import Tensor +dataset_len: Callable[[Dataset], int] = lambda dataset: len(cast(Sized, dataset)) class IITDataset(Dataset): """ @@ -12,7 +15,12 @@ class IITDataset(Dataset): """ def __init__( - self, base_data, ablation_data, seed=0, every_combination=False, device=DEVICE + self, + base_data: Dataset, + ablation_data: Dataset, + seed: int = 0, + every_combination: bool = False, + device: t.device = DEVICE ): # For vanilla IIT, base_data and ablation_data are the same self.base_data = base_data @@ -21,58 +29,66 @@ def __init__( self.every_combination = every_combination self.device = 
device - def __getitem__(self, index): + def __getitem__(self, index: int) -> tuple: if self.every_combination: - base_index = index // len(self.ablation_data) - ablation_index = index % len(self.ablation_data) + base_index = index // dataset_len(self.ablation_data) + ablation_index = index % dataset_len(self.ablation_data) base_input = self.base_data[base_index] ablation_input = self.ablation_data[ablation_index] return base_input, ablation_input # sample based on seed rng = np.random.default_rng(self.seed * 1000000 + index) - base_index = rng.choice(len(self.base_data)) - ablation_index = rng.choice(len(self.ablation_data)) + base_index = rng.choice(dataset_len(self.base_data)) + ablation_index = rng.choice(dataset_len(self.ablation_data)) base_input = self.base_data[base_index] ablation_input = self.ablation_data[ablation_index] return base_input, ablation_input - def __len__(self): + def __len__(self) -> int: if self.every_combination: - return len(self.base_data) * len(self.ablation_data) - return len(self.base_data) + return dataset_len(self.base_data) * dataset_len(self.ablation_data) + return dataset_len(self.base_data) @staticmethod - def get_encoded_input_from_torch_input(xy, device=DEVICE): + def get_encoded_input_from_torch_input( + xy: tuple, + device: t.device = DEVICE + ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]: zipped_data = tuple(zip(*xy)) - x, y = zipped_data[0:2] - x = torch.stack([x_i.to(device) for x_i in x]) - y = torch.stack([y_i.to(device) for y_i in y]) + x_in, y_in = zipped_data[0:2] + x = t.stack([x_i.to(device) for x_i in x_in]) + y = t.stack([y_i.to(device) for y_i in y_in]) if len(zipped_data) == 3: - int_vars = zipped_data[2] - int_vars = torch.stack([iv.to(device) for iv in int_vars]) + int_vars_in = zipped_data[2] + int_vars = t.stack([iv.to(device) for iv in int_vars_in]) return x, y, int_vars else: return x, y @staticmethod - def collate_fn(batch, device=DEVICE): + def collate_fn( + batch: list[Tensor] | Tensor, + device: t.device = DEVICE + ) -> tuple[tuple, tuple]: if not isinstance(batch, list): # if batch is a single element, because batch_size was 1 or None, it is a tuple instead of a list - batch = [batch] - - base_input, ablation_input = zip(*batch) + batch_list = [batch] + else: + batch_list = batch + + base_input_list, ablation_input_list = zip(*batch_list) return IITDataset.get_encoded_input_from_torch_input( - base_input, device - ), IITDataset.get_encoded_input_from_torch_input(ablation_input, device) + base_input_list, device + ), IITDataset.get_encoded_input_from_torch_input(ablation_input_list, device) def make_loader( self, - batch_size, - num_workers, + batch_size: int, + num_workers: int, ) -> DataLoader: return DataLoader( self, @@ -81,15 +97,21 @@ def make_loader( num_workers=num_workers, collate_fn=lambda x: self.collate_fn(x, self.device), ) + + def get_input_shape(self) -> t.Size: + return self.base_data.get_input_shape() # type: ignore - -def train_test_split(dataset, test_size=0.2, random_state=None): - n = len(dataset) +def train_test_split( + dataset: Dataset, + test_size: float = 0.2, + random_state: Optional[int] = None + ) -> list[t.utils.data.Subset]: + n = dataset_len(dataset) split = int(n * test_size) if random_state is None: - return torch.utils.data.random_split(dataset, [n - split, split]) - return torch.utils.data.random_split( + return t.utils.data.random_split(dataset, [n - split, split]) + return t.utils.data.random_split( dataset, [n - split, split], - 
generator=torch.Generator().manual_seed(random_state), + generator=t.Generator().manual_seed(random_state), ) diff --git a/iit/utils/index.py b/iit/utils/index.py index 28348b4..865d040 100644 --- a/iit/utils/index.py +++ b/iit/utils/index.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, Optional class TorchIndex: @@ -35,11 +35,14 @@ def __init__( ) ) - def __hash__(self): + def __hash__(self) -> int: return hash(self.hashable_tuple) - def __eq__(self, other): - return self.hashable_tuple == other.hashable_tuple + def __eq__(self, other: object) -> bool: + if isinstance(other, TorchIndex): + return self.hashable_tuple == other.hashable_tuple + else: + return False def __repr__(self) -> str: ret = "[" @@ -63,10 +66,10 @@ def __repr__(self) -> str: ret += "]" return ret - def graphviz_index(self, use_actual_colon=True) -> str: - return self.__repr__(use_actual_colon=use_actual_colon) + def graphviz_index(self) -> str: + return self.__repr__() - def intersects(self, other) -> bool: + def intersects(self, other: Optional["TorchIndex"]) -> bool: if other is None or self == Ix[[None]] or other == Ix[[None]]: return True # None means all indices if len(self.as_index) != len(other.as_index): @@ -104,8 +107,7 @@ def intersects(self, other) -> bool: class Index: """A class purely for syntactic sugar, that returns the index it's indexed with""" - def __getitem__(self, index): + def __getitem__(self, index: Iterable) -> TorchIndex: return TorchIndex(index) - Ix = Index() diff --git a/iit/utils/io_scripts.py b/iit/utils/io_scripts.py index 6658ab3..55b9e8d 100644 --- a/iit/utils/io_scripts.py +++ b/iit/utils/io_scripts.py @@ -1,11 +1,15 @@ import os import json -import torch + +import torch as t import wandb + from iit.tasks.task_loader import get_default_corr +from iit.model_pairs.base_model_pair import BaseModelPair +from iit.utils.argparsing import IOIArgParseNamespace -def save_model(model_pair, args, task): +def save_model(model_pair: BaseModelPair, args: IOIArgParseNamespace, task: str) -> None: """ Folder structure: -ll_models @@ -29,7 +33,7 @@ def save_model(model_pair, args, task): os.makedirs(results_dir, exist_ok=True) # save model - torch.save(ll_model.state_dict(), f"{save_dir}/ll_model_{model_suffix}.pth") + t.save(ll_model.state_dict(), f"{save_dir}/ll_model_{model_suffix}.pth") # save training args training_args_file = os.path.join(results_dir, "training_args.json") @@ -47,7 +51,7 @@ def save_model(model_pair, args, task): with open(metrics_file, "w") as f: f.write(f"Epochs: {epochs}\n") early_stop_condition = model_pair._check_early_stop_condition( - model_pair.test_metrics.metrics + model_pair.test_metrics ) f.write(f"Early stop: {early_stop_condition}\n") f.write("\n\n--------------------------------\n\n") @@ -74,8 +78,13 @@ def save_model(model_pair, args, task): def load_files_from_wandb( - task, weights, next_token, files_to_download, base_path, include_mlp=True -): + task: str, + weights: str, + next_token: bool, + files_to_download: list[str], + base_path: str, + include_mlp: bool = True +) -> None: api = wandb.Api() runs = api.runs("iit_models") next_token_str = "_next_token" if next_token else "" diff --git a/iit/utils/logger.py b/iit/utils/logger.py index f497d12..f9d14b1 100644 --- a/iit/utils/logger.py +++ b/iit/utils/logger.py @@ -1,11 +1,14 @@ +import sys import os -import torch -import numpy as np import time -np.set_printoptions(threshold=np.inf) + +import numpy as np +import torch as t +from torch import Tensor 
+np.set_printoptions(threshold=sys.maxsize)


 class LoggingDict(dict):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs):  # type: ignore
         dirname = "logs"
         if not os.path.exists(dirname):
             os.makedirs(dirname)
@@ -13,25 +16,25 @@ def __init__(self, *args, **kwargs):

         super().__init__(*args, **kwargs)

-    def compare(self, x, y):
-        if isinstance(x, (torch.Tensor)):
-            assert isinstance(y, (torch.Tensor)), "x and y are not the same type"
-            return (x == y).all()
+    def compare(self, x: t.Any, y: t.Any) -> bool:
+        if isinstance(x, (Tensor)):
+            assert isinstance(y, (Tensor)), "x and y are not the same type"
+            return bool((x == y).all())
         elif isinstance(x, (np.ndarray)):
             assert isinstance(y, (np.ndarray)), "x and y are not the same type"
-            return (x == y).all()
+            return bool((x == y).all())
         elif isinstance(x, (list)):
             assert isinstance(y, (list)), "x and y are not the same type"
-            return all(self.compare(x[i], y[i]) for i in range(len(x)))
+            return bool(all(self.compare(x[i], y[i]) for i in range(len(x))))
         else:
             return x == y

-    def convert_tensor_to_numpy(self, x):
-        if isinstance(x, (torch.Tensor)):
+    def convert_tensor_to_numpy(self, x: Tensor | np.ndarray) -> np.ndarray:
+        if isinstance(x, (Tensor)):
             return x.cpu().detach().numpy()
         return x

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: t.Any, value: t.Any) -> None:
         if key not in self:
             with open(self._log_filename, "a") as f:
                 f.write(f"{key}\n initial value: {value}\n")
@@ -42,7 +45,7 @@ def __setitem__(self, key, value):


 if __name__ == "__main__":
-    logger = LoggingDict()
+    logger = LoggingDict()  # type: ignore
     logger["a"] = 1
     logger["b"] = 2
     logger["c"] = 3
diff --git a/iit/utils/metric.py b/iit/utils/metric.py
index 1184e59..54e5993 100644
--- a/iit/utils/metric.py
+++ b/iit/utils/metric.py
@@ -1,3 +1,5 @@
+from typing import Iterator
+
 from enum import Enum

 import numpy as np
@@ -9,16 +11,16 @@ class MetricType(Enum):


 class MetricStore:
-    def __init__(self, name, metric_type: MetricType):
+    def __init__(self, name: str, metric_type: MetricType):
         self._name = name
         self.type = metric_type
-        self._store = []
+        self._store: list = []
         assert self.type in MetricType, f"Invalid metric type {self.type}"

-    def append(self, metric):
+    def append(self, metric: float | list) -> None:
         self._store.append(metric)

-    def get_value(self):
+    def get_value(self) -> None | float | list | np.ndarray:
         if len(self._store) == 0:
             return None
         if self.type == MetricType.ACCURACY:
@@ -26,17 +28,21 @@ def get_value(self):
         else:
             return np.mean(self._store)

-    def get_name(self):
+    def get_name(self) -> str:
         return self._name

     def __str__(self) -> str:
         if self.get_value() is None:
             return f"{self._name}: None"
         if self.type == MetricType.ACCURACY:
-            return f"{self._name}: {float(self.get_value()):.2f}%"
+            val = self.get_value()
+            if val is None:
+                return f"{self._name}: None"
+            else:
+                return f"{self._name}: {val:.2f}%"
         return f"{self._name}: {self.get_value():.4f}"

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._store)

     def __repr__(self) -> str:
@@ -44,14 +50,11 @@ def __repr__(self) -> str:


 class PerTokenMetricStore(MetricStore):
-    def __init__(self, name, **kwargs):
+    def __init__(self, name: str, precision: int = 3):
         super().__init__(name, metric_type=MetricType.LOG)
-        if "precision" in kwargs:
-            np.set_printoptions(precision=kwargs["precision"])
-        else:
-            np.set_printoptions(precision=3)
+        np.set_printoptions(precision=precision)

-    def get_value(self):
+    def get_value(self) -> None | float:
         if len(self._store) == 0:
             return None
         return np.mean(self._store, axis=0)
@@ -64,7 +67,7 @@ class MetricStoreCollection:
     def __init__(self, list_of_metric_stores: list[MetricStore]):
         self.metrics = list_of_metric_stores

-    def update(self, metrics: dict[str, float]):
+    def update(self, metrics: dict[str, float | list]) -> None:
         for k, v in metrics.items():
             key_found = False
             for metric in self.metrics:
@@ -78,7 +81,7 @@
             np.unique(lengths).shape[0] == 1
         ), f"All metric stores should have the same length after update!, got lengths: {lengths}"

-    def create_metric_store(self, name, metric_type: MetricType):
+    def create_metric_store(self, name: str, metric_type: MetricType) -> MetricStore:
         assert [
             len(metric._store) == 0 for metric in self.metrics
         ], "All metric stores should be empty before creating a new one!"
@@ -92,5 +95,8 @@ def __str__(self) -> str:
     def __repr__(self) -> str:
         return self.__str__()

-    def to_dict(self):
+    def to_dict(self) -> dict:
         return {metric.get_name(): metric.get_value() for metric in self.metrics}
+
+    def __iter__(self) -> Iterator[MetricStore]:
+        return iter(self.metrics)
diff --git a/iit/utils/node_picker.py b/iit/utils/node_picker.py
index 3537c74..9449fc0 100644
--- a/iit/utils/node_picker.py
+++ b/iit/utils/node_picker.py
@@ -71,7 +71,7 @@ def get_post_nodes_not_in_circuit(
     print("WARNING: This doesn't work when switching individual heads on/off.")
     suffixes = hl_ll_corr.get_suffixes()
     nodes_not_in_circuit = get_nodes_not_in_circuit(ll_model, hl_ll_corr)
-    post_nodes_not_in_circuit = []
+    post_nodes_not_in_circuit: list[LLNode] = []
     for node in nodes_not_in_circuit:
         layer = int(node.name.split(".")[1])
         if "attn" in node.name and "attn" in suffixes:
@@ -179,7 +179,7 @@ def get_params_not_in_circuit(
             params_not_in_circuit.append(param)
     return params_not_in_circuit

-def find_ll_node_by_name(name, list_of_nodes: list[LLNode]):
+def find_ll_node_by_name(name: str, list_of_nodes: list[LLNode]) -> list[LLNode]:
     ll_nodes = []
     for node in list_of_nodes:
         if node.name == name:
diff --git a/iit/utils/nodes.py b/iit/utils/nodes.py
index faddf04..cf94617 100644
--- a/iit/utils/nodes.py
+++ b/iit/utils/nodes.py
@@ -11,16 +11,12 @@ class HLNode:
     name: HookName
     num_classes: int
-    index: Optional[TorchIndex] = Ix[[None]]
-
-    def __post_init__(self):
-        if self.index is None:
-            self.index = Ix[[None]]
+    index: TorchIndex = Ix[[None]]

     def __hash__(self) -> int:
         return hash(self.name)

-    def __eq__(self, other) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, HLNode):
             return self.name == other.name
         elif isinstance(other, str):
@@ -37,20 +33,18 @@ def __repr__(self) -> str:
 @dataclass
 class LLNode:
     name: HookName
-    index: TorchIndex
+    index: TorchIndex = Ix[[None]]
     subspace: Optional[t.Tensor] = None

-    def __post_init__(self):
-        if self.index is None:
-            self.index = Ix[[None]]
-
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return isinstance(other, LLNode) and dataclasses.astuple(
             self
         ) == dataclasses.astuple(other)

-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(dataclasses.astuple(self))

-    def get_index(self):
+    def get_index(self) -> tuple[slice]:
+        if self.index is None:
+            raise ValueError("Index is None, which should not happen after __post_init__. Perhaps you set it to None manually?")
         return self.index.as_index
\ No newline at end of file
diff --git a/iit/utils/plotter.py b/iit/utils/plotter.py
index 002dc3a..f45a552 100644
--- a/iit/utils/plotter.py
+++ b/iit/utils/plotter.py
@@ -1,38 +1,34 @@
+from typing import Callable
+
 import matplotlib.pyplot as plt
 import numpy as np
 import wandb
-from PIL import Image

 from iit.model_pairs.base_model_pair import HLNode


-def get_hookpoint_labels(hookpoints):
+def get_hookpoint_labels(hookpoints: list[str]) -> list[str]:
     return [
         i.replace("mod.", "").replace(".hook_point", "").replace(".", " ")
         for i in hookpoints
     ]


-def get_leaky_hlnode_labels(hl_nodes):
+def get_leaky_hlnode_labels(hl_nodes: list[str | HLNode]) -> list[str]:
     x_tick_string = """{} -> {}"""
-    if type(hl_nodes[0]) == HLNode:
-        hl_nodes = [i.name for i in hl_nodes]
-    elif type(hl_nodes[0]) != str:
-        raise ValueError(
-            f"hl_nodes must be a list of str or HLNode, got {type(hl_nodes[0])}"
-        )
-    return [x_tick_string.format(i.split("_")[1], i.split("_")[-1]) for i in hl_nodes]
+    hl_nodes_pass = [node if isinstance(node, str) else node.name for node in hl_nodes]
+    return [x_tick_string.format(i.split("_")[1], i.split("_")[-1]) for i in hl_nodes_pass]


 def plot_probe_stats(
-    correctness_stats_per_layer,
-    leaky_stats_per_layer,
-    reduction="max",
-    prefix="",
-    use_wandb=False,
-):
+    correctness_stats_per_layer: dict[str, dict],
+    leaky_stats_per_layer: dict[str, dict],
+    reduction: str = "max",
+    prefix: str = "",
+    use_wandb: bool = False,
+) -> None:
     # make arrays
     hookpoints = list(correctness_stats_per_layer.keys())
-    get_hl_nodes = lambda stats: list(stats[list(stats.keys())[0]]["probes"].keys())
+    get_hl_nodes: Callable[[dict], list] = lambda stats: list(stats[list(stats.keys())[0]]["probes"].keys())
     hl_nodes = get_hl_nodes(correctness_stats_per_layer)
     # correctness_loss = np.zeros((len(hookpoints), len(hl_nodes)))
     correctness_acc = np.zeros((len(hookpoints), len(hl_nodes)))
@@ -46,8 +42,8 @@ def plot_probe_stats(
     leaky_acc = np.zeros((len(hookpoints), len(hl_nodes)))
     leaky_hl_nodes = get_hl_nodes(leaky_stats_per_layer)
     leaky_accs_all = np.zeros((len(hookpoints), len(leaky_hl_nodes)))
-    get_idx = lambda name: hl_nodes.index("hook_{}".format(name))
-    reduction = (
+    get_idx: Callable[[str], int] = lambda name: hl_nodes.index("hook_{}".format(name))
+    reduction_function = (
         np.mean
         if reduction == "mean"
         else (
@@ -57,7 +53,7 @@
         )
     )
     assert (
-        reduction is not None
+        reduction_function is not None
     ), f"reduction must be one of 'mean', 'max', or 'median', got {reduction}"
     for i, hookpoint in enumerate(hookpoints):
         accs = np.zeros((len(hl_nodes), len(hl_nodes)))
@@ -69,7 +65,7 @@
                 accs[get_idx(leaked_from), get_idx(leaked_to)] = acc
                 leaky_accs_all[i, j] = acc
         # print(f"accs: {accs}")
-        accs_reduced = reduction(accs, axis=0)  # rows = leaked_from; cols = leaked_to
+        accs_reduced = reduction_function(accs, axis=0)  # rows = leaked_from; cols = leaked_to
         for j, _ in enumerate(hl_nodes):
             leaky_acc[i, j] = accs_reduced[j]
     # print(f"leaky_acc: {leaky_acc}")
@@ -77,6 +73,7 @@
     # plot
     # TODO: maybe add this: https://stackoverflow.com/questions/9707676/defining-a-discrete-colormap-for-imshow
     fig, ax = plt.subplots(1, 2, figsize=(20, 10))
+    assert isinstance(ax, np.ndarray), "ax must be a numpy array to be indexed"

     # print(f"correctness_acc: {correctness_acc}")
     # print(f"leaky_acc: {leaky_acc}")
@@ -124,7 +121,7 @@
     print("Plotted probe stats. Find them in plots folder.")


-def plot_ablation_stats(stats_per_layer, prefix="", use_wandb=False):
+def plot_ablation_stats(stats_per_layer: dict[str, dict], prefix: str = "", use_wandb: bool = False) -> None:
     # make arrays
     hookpoints = list(stats_per_layer.keys())
     hl_nodes = list(stats_per_layer[hookpoints[0]].keys())
diff --git a/iit/utils/probes.py b/iit/utils/probes.py
index d63f311..90458d2 100644
--- a/iit/utils/probes.py
+++ b/iit/utils/probes.py
@@ -1,22 +1,26 @@
-import torch
+from typing import Callable
+
+import torch as t
 import torch.nn as nn
-from iit.model_pairs.base_model_pair import HLNode, LLNode, BaseModelPair
+from torch import Tensor
 from tqdm import tqdm
+
+from iit.model_pairs.base_model_pair import HLNode, LLNode, BaseModelPair
 from iit.utils.config import DEVICE


 def construct_probe(
     high_level_node: HLNode,
     ll_nodes: set[LLNode],
-    dummy_cache: dict[str, torch.Tensor],
-    bias=False,
-):
+    dummy_cache: dict[str, t.Tensor],
+    bias: bool = False,
+) -> nn.Linear:
     """
     Makes a probe for a given high-level node, given the low-level model and nodes.
     """
     if len(ll_nodes) > 1:
         raise NotImplementedError  # raising as unsure about summing over multiple nodes
-    _get_hook_out_size = (
+    _get_hook_out_size: Callable[[dict[str, t.Tensor], LLNode], int] = (
         lambda dummy_cache, ll_node: dummy_cache[ll_node.name][ll_node.index.as_index]
         .flatten()
         .shape[0]
     )
@@ -27,48 +31,48 @@ def construct_probe(
     return nn.Linear(size, high_level_node.num_classes, bias=bias).to(DEVICE)


-def construct_probes(model_pair: BaseModelPair, input_shape: tuple[int], bias=False):
+def construct_probes(model_pair: BaseModelPair, input_shape: t.Size, bias: bool = False) -> dict[HLNode, nn.Linear]:
     probes = {}
     _, dummy_cache = model_pair.ll_model.run_with_cache(
-        torch.zeros(input_shape).to(DEVICE)
+        t.zeros(input_shape).to(DEVICE)
     )
     for hl_node, ll_nodes in model_pair.corr.items():
         probe = construct_probe(hl_node, ll_nodes, dummy_cache, bias=bias)
-        probes[hl_node.name] = probe
+        probes[hl_node] = probe
     return probes


 def train_probes_on_model_pair(
     model_pair: BaseModelPair,
-    input_shape: str,
-    train_set: torch.utils.data.Dataset,
+    input_shape: t.Size,
+    train_set: t.utils.data.Dataset,
     training_args: dict,
-):
+) -> dict[str, dict]:
     probes = construct_probes(model_pair, input_shape=input_shape)
     params = []
     for p in probes.values():
         p.train()
         params += list(p.parameters())
-    probe_optimizer = torch.optim.Adam(params, lr=training_args["lr"])
+    probe_optimizer = t.optim.Adam(params, lr=training_args["lr"])
     criterion = nn.CrossEntropyLoss()
-    probe_losses = {k: [] for k in probes.keys()}
-    probe_accuracies = {k: [] for k in probes.keys()}
-    loader = torch.utils.data.DataLoader(
+    probe_losses: dict[HLNode, list[Tensor]] = {k: [] for k in probes.keys()}
+    probe_accuracies: dict[HLNode, list[Tensor]] = {k: [] for k in probes.keys()}
+    loader = t.utils.data.DataLoader(
         train_set,
         batch_size=training_args["batch_size"],
         shuffle=True,
         num_workers=training_args["num_workers"],
     )
     for _ in tqdm(range(training_args["epochs"])):
-        probe_accuracy_run = {k: 0 for k in probes.keys()}
-        probe_loss_run = {k: 0 for k in probes.keys()}
+        probe_accuracy_run = {k: t.zeros(1) for k in probes.keys()}
+        probe_loss_run = {k: t.zeros(1) for k in probes.keys()}
         for x, y, int_vars in loader:
             probe_optimizer.zero_grad()
             x = x.to(DEVICE)
             y = y.to(DEVICE)
             out, cache = model_pair.ll_model.run_with_cache(x)
-            probe_loss = 0
+            probe_loss = t.zeros(1)
             for hl_node_name, probe in probes.items():
                 ll_nodes = model_pair.corr[hl_node_name]
                 gt = model_pair.hl_model.get_idx_to_intermediate(hl_node_name)(
@@ -88,7 +92,7 @@ def train_probes_on_model_pair(
                 probe_accuracy_run[hl_node_name] += (
                     (probe_out.argmax(1) == gt).float().mean().item()
                 )
-            probe_loss.backward()
+            probe_loss.backward()  # type: ignore
             probe_optimizer.step()
         for k in probe_losses.keys():
             probe_losses[k].append(probe_loss_run[k] / len(loader))
@@ -96,18 +100,23 @@
     return {"probes": probes, "loss": probe_losses, "accuracy": probe_accuracies}


-def evaluate_probe(probes, model_pair, test_set, criterion):
-    probe_stats = {}
+def evaluate_probe(
+    probes: dict[str, nn.Linear],
+    model_pair: BaseModelPair,
+    test_set: t.utils.data.Dataset,
+    criterion: Callable[[Tensor, Tensor], Tensor]
+    ) -> dict[str, dict]:
+    probe_stats: dict[str, dict] = {}
     probe_stats["test loss"] = {}
     probe_stats["test accuracy"] = {}
     for hl_node_name, probe in tqdm(probes.items(), desc="Evaluating probes"):
         probe.eval()
-        probe_loss = 0
-        probe_accuracy = 0
-        loader = torch.utils.data.DataLoader(
+        probe_loss = t.zeros(1)
+        probe_accuracy = t.zeros(1)
+        loader = t.utils.data.DataLoader(
             test_set, batch_size=256, shuffle=True, num_workers=0
         )
-        with torch.no_grad():
+        with t.no_grad():
             for x, y, int_vars in loader:
                 x = x.to(DEVICE)
                 y = y.to(DEVICE)
diff --git a/iit/utils/train_scripts.py b/iit/utils/train_scripts.py
index 0772dd8..5a55076 100644
--- a/iit/utils/train_scripts.py
+++ b/iit/utils/train_scripts.py
@@ -1,4 +1,4 @@
-import transformer_lens
+from transformer_lens import HookedTransformer

 from iit.model_pairs.ioi_model_pair import IOI_ModelPair
 from iit.utils.iit_dataset import train_test_split
@@ -12,13 +12,11 @@
     suffixes
 )
 from iit.utils.correspondence import Correspondence
-from argparse import Namespace
+from iit.utils.argparsing import IOIArgParseNamespace

-
-def train_ioi(
-    args: Namespace,
-):
-    device = args.device
+def train_ioi(args: IOIArgParseNamespace) -> IOI_ModelPair:
+    device = t.device(args.device)
     num_samples = args.num_samples
     epochs = args.epochs
     use_wandb = args.use_wandb
@@ -38,16 +36,16 @@ def train_ioi(
     t.manual_seed(0)
     np.random.seed(0)

-    ll_cfg = transformer_lens.HookedTransformer.from_pretrained(
+    ll_cfg = HookedTransformer.from_pretrained(
         "gpt2"
     ).cfg.to_dict()
     ll_cfg.update(ioi_cfg)

     ll_cfg["init_weights"] = True
-    ll_model = transformer_lens.HookedTransformer(ll_cfg).to(device)
+    ll_model = HookedTransformer(ll_cfg).to(device)
     print("making ioi dataset and hl")
     ioi_dataset, hl_model = make_ioi_dataset_and_hl(
-        num_samples, ll_model, NAMES, device=args.device, verbose=True
+        num_samples, ll_model, NAMES, device=device, verbose=True
     )
     print("making IIT dataset")
     train_ioi_dataset, test_ioi_dataset = train_test_split(
diff --git a/iit/utils/wrapper.py b/iit/utils/wrapper.py
index d794b3a..e98f86e 100644
--- a/iit/utils/wrapper.py
+++ b/iit/utils/wrapper.py
@@ -1,3 +1,5 @@
+from typing import Callable
+
 import torch as t
 from torch import Tensor
 from transformer_lens.hook_points import HookedRootModule, HookPoint
@@ -11,29 +13,30 @@ class HookedModuleWrapper(HookedRootModule):
     def __init__(
         self,
         mod: t.nn.Module,
-        name="model",
-        recursive=False,
-        hook_self=True,
-        top_level=True,
-        hook_pre=False,
+        name: str = "model",
+        recursive: bool = False,
+        get_hook_self: bool = True,
+        get_hook_pre: bool = False,
     ):
         super().__init__()
         self.mod = mod  # deepcopy(mod)
-        self.hook_self = hook_self
-        self.hook_pre = hook_pre
-        if hook_pre:
+        if get_hook_pre:
             self.hook_pre = HookPoint()
             self.hook_pre.name = name + "pre"
-        if hook_self:
+        else:
+            self.hook_pre = None
+        if get_hook_self:
             hook_point = HookPoint()
             hook_point.name = name
             self.hook_point = hook_point
+        else:
+            self.hook_point = None
         if recursive:
             self.wrap_hookpoints_recursively()
         self.setup()

-    def wrap_hookpoints_recursively(self, verbose=False):
-        show = lambda *args: print(*args) if verbose else None
+    def wrap_hookpoints_recursively(self, verbose: bool = False) -> None:
+        show: Callable[[t.Any], None] = lambda *args: print(*args) if verbose else None
         for key, submod in list(self.mod._modules.items()):
             if isinstance(submod, HookedModuleWrapper):
                 show(f"SKIPPING {key}:{type(submod)}")
@@ -45,17 +48,18 @@ def wrap_hookpoints_recursively(self, verbose=False):
                 show(f"INDIVIDUALLY WRAPPING {key}:{type(submod)}")
                 for i, subsubmod in enumerate(submod):
                     new_submod = HookedModuleWrapper(
-                        subsubmod, name=f"{key}.{i}", recursive=True, top_level=False
+                        subsubmod, name=f"{key}.{i}", recursive=True
                     )
                     submod[i] = new_submod
                 continue
-            # print(f'wrapping {key}:{type(submod)}')
-            new_submod = HookedModuleWrapper(
-                submod, name=key, recursive=True, top_level=False
-            )
-            self.mod.__setattr__(key, new_submod)

-    def forward(self, *args, **kwargs):
+            if isinstance(submod, t.nn.Module):
+                new_submod = HookedModuleWrapper(
+                    submod, name=key, recursive=True
+                )
+                self.mod.__setattr__(key, new_submod)
+
+    def forward(self, *args, **kwargs) -> Tensor:  # type: ignore
         if self.hook_pre:
             result = self.mod.forward(self.hook_pre(*args, **kwargs))
         else:
@@ -65,6 +69,5 @@ def forward(self, *args, **kwargs):
         assert isinstance(result, Tensor)
         return self.hook_point(result)

-
-def get_hook_points(model: HookedRootModule):
+def get_hook_points(model: HookedRootModule) -> list[str]:
     return [k for k in list(model.hook_dict.keys()) if "conv" in k]
diff --git a/mypy-3.10.ini b/mypy-3.10.ini
new file mode 100644
index 0000000..0fbbb8f
--- /dev/null
+++ b/mypy-3.10.ini
@@ -0,0 +1,10 @@
+[mypy]
+python_version = 3.10
+disallow_untyped_calls = True
+disallow_untyped_defs = True
+disallow_incomplete_defs = True
+check_untyped_defs = True
+ignore_missing_imports = True
+files = ./iit/,./*.py
+exclude = ./plots|./tests
+
diff --git a/mypy-3.11.ini b/mypy-3.11.ini
new file mode 100644
index 0000000..2b98351
--- /dev/null
+++ b/mypy-3.11.ini
@@ -0,0 +1,10 @@
+[mypy]
+python_version = 3.11
+disallow_untyped_calls = True
+disallow_untyped_defs = True
+disallow_incomplete_defs = True
+check_untyped_defs = True
+ignore_missing_imports = True
+files = ./iit/,./*.py
+exclude = ./plots|./tests
+
diff --git a/mypy-3.12.ini b/mypy-3.12.ini
new file mode 100644
index 0000000..b000829
--- /dev/null
+++ b/mypy-3.12.ini
@@ -0,0 +1,10 @@
+[mypy]
+python_version = 3.12
+disallow_untyped_calls = True
+disallow_untyped_defs = True
+disallow_incomplete_defs = True
+check_untyped_defs = True
+ignore_missing_imports = True
+files = ./iit/,./*.py
+exclude = ./plots|./tests
+
diff --git a/pyproject.toml b/pyproject.toml
index 7277c21..b1ae641 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@
 pytest = "^8.2.2"
 pillow = "^10.2.0"
 torchvision = "^0.18.1"
 matplotlib = "^3.8.2"
+mypy = "^1.11.0"

 [tool.poetry.group.dev.dependencies]
diff --git a/tests/test_metric_logger.py b/tests/test_metric_logger.py
index 71485f0..46e4e49 100644
--- a/tests/test_metric_logger.py
+++ b/tests/test_metric_logger.py
@@ -2,6 +2,8 @@
 from iit.model_pairs.iit_model_pair import IITModelPair
 from iit.model_pairs.ioi_model_pair import IOI_ModelPair

+from .test_model_pairs import get_test_model_pair_ingredients
+

 def test_metric_collection():
     mc = MetricStoreCollection(
@@ -25,10 +27,13 @@ def test_early_stop():
     mc.update({"acc": 0.5, "loss": 0.2, "new_acc": 0.6})
     mc.update({"acc": 0.7, "loss": 0.1, "new_acc": 0.8})

-    es_condition = IITModelPair._check_early_stop_condition(mc.metrics)
+    ll_model, hl_model, corr, _, _ = get_test_model_pair_ingredients()
+    mod_pair = IITModelPair(hl_model, ll_model, corr)
+
+    es_condition = mod_pair._check_early_stop_condition(mc.metrics)
     assert es_condition == False

     mc.update({"acc": 0.991, "loss": 0.1, "new_acc": 0.99})
-    es_condition = IITModelPair._check_early_stop_condition(mc.metrics)
+    es_condition = mod_pair._check_early_stop_condition(mc.metrics)
     assert es_condition == False
     mc = MetricStoreCollection(
         [
@@ -37,8 +42,9 @@ def test_early_stop():
             MetricStore("new_acc", MetricType.ACCURACY),
         ]
     )
-    mc.update({"acc": 100, "loss": 0.1, "new_acc": 100})
-    es_condition = IITModelPair._check_early_stop_condition(mc.metrics)
+    mc.update({"acc": 1, "loss": 0.1, "new_acc": 1})
+    print(mc.metrics)
+    es_condition = mod_pair._check_early_stop_condition(mc.metrics)
     assert es_condition == True
diff --git a/tests/test_model_pairs.py b/tests/test_model_pairs.py
index 325f3d7..f9309bd 100644
--- a/tests/test_model_pairs.py
+++ b/tests/test_model_pairs.py
@@ -6,7 +6,7 @@
 from iit.model_pairs.ll_model import LLModel
 import torch

-def test_model_pair_gradients():
+def get_test_model_pair_ingredients():
     ll_model = LLModel(cfg={
         'n_layers': 4,
         'd_model': 32,
@@ -30,13 +30,18 @@
     corr = Correspondence()
     hook_point = 'blocks.1.attn.hook_z'
     hook_idx = index.Ix[:, :, 0]
-    hook_idx_complement = index.Ix[:, :, 1:]
-    prev_hooks = ['blocks.0.attn.hook_z', 'blocks.0.mlp.hook_post']
-    next_hooks = ['blocks.2.attn.hook_z', 'blocks.3.attn.hook_z', 'blocks.1.mlp.hook_post', 'blocks.2.mlp.hook_post', 'blocks.3.mlp.hook_post']
     corr.update({
         HLNode(hook_point, -1, index=hook_idx) : [LLNode(hook_point, index=hook_idx)],
         }
     )
+    return ll_model, hl_model, corr, hook_point, hook_idx
+
+def test_model_pair_gradients():
+    ll_model, hl_model, corr, hook_point, hook_idx = get_test_model_pair_ingredients()
+
+    hook_idx_complement = index.Ix[:, :, 1:]
+    prev_hooks = ['blocks.0.attn.hook_z', 'blocks.0.mlp.hook_post']
+    next_hooks = ['blocks.2.attn.hook_z', 'blocks.3.attn.hook_z', 'blocks.1.mlp.hook_post', 'blocks.2.mlp.hook_post', 'blocks.3.mlp.hook_post']

     model_pair = CachingModelPair(ll_model=ll_model, hl_model=hl_model, corr=corr)
diff --git a/tests/test_node_picker.py b/tests/test_node_picker.py
index 749e799..84ec2fa 100644
--- a/tests/test_node_picker.py
+++ b/tests/test_node_picker.py
@@ -122,8 +122,8 @@ def test_suffix_maker():
         "mlp": "mlp.hook_post"
     }

-    all_attns = [LLNode(f"blocks.{i}.attn.hook_result", index=None) for i in range(n_layers)]
-    all_mlps = [LLNode(f"blocks.{i}.mlp.hook_post", index=None) for i in range(n_layers)]
+    all_attns = [(f"blocks.{i}.attn.hook_result", Ix[[None]], None) for i in range(n_layers)]
+    all_mlps = [(f"blocks.{i}.mlp.hook_post", Ix[[None]], None) for i in range(n_layers)]

     corr_dict = {
         "all_nodes_hook": [all_mlps[3], all_attns[0]]
@@ -134,4 +134,4 @@ def test_suffix_maker():
     assert hook_suffixes == {
         "attn": "attn.hook_result",
         "mlp": "mlp.hook_post"
-    }
\ No newline at end of file
+    }
diff --git a/train.py b/train.py
index 50fd456..d4e74fd 100644
--- a/train.py
+++ b/train.py
@@ -14,8 +14,9 @@
 task = "mnist_pvr"
 train_set, test_set = get_dataset(task, dataset_config=dataset_config)
 ll_model, hl_model, corr = get_alignment(
-    task, config={"input_shape": test_set.base_data.get_input_shape()}
+    task, config={"input_shape": test_set.base_data.get_input_shape()}  # type: ignore
 )
+assert ll_model is not None
 model_pair = IITBehaviorModelPair(
     ll_model=ll_model, hl_model=hl_model, corr=corr, training_args=training_args
 )  # TODO: add wrapper for choosing model pair
diff --git a/train_ioi.py b/train_ioi.py
index 9dc53a1..6439912 100644
--- a/train_ioi.py
+++ b/train_ioi.py
@@ -1,5 +1,6 @@
 from iit.utils.train_scripts import train_ioi
 from iit.utils.io_scripts import save_model
+from iit.utils.argparsing import IOIArgParseNamespace
 import torch as t

 if __name__ == "__main__":
@@ -22,6 +23,7 @@
     parser.add_argument("--include-mlp", action="store_true")

     args = parser.parse_args()
+    namespace = IOIArgParseNamespace(**vars(args))

-    model_pair = train_ioi(args)
-    save_model(model_pair, args, "ioi")
\ No newline at end of file
+    model_pair = train_ioi(namespace)
+    save_model(model_pair, namespace, "ioi")
\ No newline at end of file
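
Note for reviewers and contributors: the per-version config files added here (mypy-3.10.ini, mypy-3.11.ini, mypy-3.12.ini) make it easy to run the same check locally before pushing. The following is a minimal sketch, not part of the diff above; it assumes mypy is installed in the active environment and that the script is run from the repository root, where those config files live:

import subprocess
import sys

# Pick the config matching the local interpreter, e.g. "mypy-3.11.ini".
config = f"mypy-{sys.version_info.major}.{sys.version_info.minor}.ini"

# The config already scopes checking via `files = ./iit/,./*.py` and
# `exclude = ./plots|./tests`, so no extra paths are needed on the command line.
subprocess.run(["mypy", "--config-file", config], check=True)

Keeping one config per Python version means version-specific syntax (such as newer union or generic forms) can be gated per interpreter without changing the shared strictness flags.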