From 6c5f4ed9bf094b000733d7230a72827e208871b6 Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Mon, 12 Aug 2024 15:56:31 -0700 Subject: [PATCH 1/9] improves training interface --- iit/model_pairs/base_model_pair.py | 77 +++++++++++++++++++----- iit/model_pairs/iit_model_pair.py | 3 +- iit/model_pairs/strict_iit_model_pair.py | 58 +++++++++++++----- 3 files changed, 105 insertions(+), 33 deletions(-) diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index cf17c88..29bc497 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -1,5 +1,6 @@ +import os from abc import ABC, abstractmethod -from typing import Any, Callable, final, Type +from typing import Any, Callable, final, Type, Optional import numpy as np import torch as t @@ -8,6 +9,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm # type: ignore from transformer_lens.hook_points import HookedRootModule, HookPoint # type: ignore +from IPython.display import clear_output import wandb # type: ignore from iit.model_pairs.ll_model import LLModel @@ -18,6 +20,24 @@ from iit.utils.index import Ix, TorchIndex from iit.utils.metric import MetricStoreCollection, MetricType +def in_notebook(): + try: + # This will only work in Jupyter notebooks + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell': + return True # Jupyter notebook or qtconsole + elif shell == 'TerminalInteractiveShell': + return False # Terminal running IPython + else: + return False # Other types of interactive shells + except NameError: + return False # Probably standard Python interpreter + +if in_notebook(): + from tqdm.notebook import tqdm +else: + from tqdm import tqdm + class BaseModelPair(ABC): hl_model: HookedRootModule @@ -246,11 +266,16 @@ def train( optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs) loss_fn = self.loss_fn scheduler_cls = training_args.get("lr_scheduler", None) + scheduler_kwargs = training_args.get("scheduler_kwargs", {}) if scheduler_cls == t.optim.lr_scheduler.ReduceLROnPlateau: mode = training_args.get("scheduler_mode", "max") - lr_scheduler = scheduler_cls(optimizer, mode=mode, factor=0.1, patience=10) + if 'patience' not in scheduler_kwargs: + scheduler_kwargs['patience'] = 10 + if 'factor' not in scheduler_kwargs: + scheduler_kwargs['factor'] = 0.1 + lr_scheduler = scheduler_cls(optimizer, mode=mode, **scheduler_kwargs) elif scheduler_cls: - lr_scheduler = scheduler_cls(optimizer) + lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs) if use_wandb and not wandb.run: wandb.init(project="iit", name=wandb_name_suffix, @@ -260,20 +285,33 @@ def train( wandb.config.update(training_args) wandb.config.update({"method": self.wandb_method}) wandb.run.log_code() # type: ignore + + epoch_pbar = tqdm(range(epochs), desc="Training Epochs") + batch_pbar = tqdm(total=len(train_loader), desc="Training Batches") + for epoch in range(epochs): + batch_pbar.reset() - for epoch in tqdm(range(epochs)): - train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer) + train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer, batch_pbar) test_metrics = self._run_eval_epoch(test_loader, loss_fn) if scheduler_cls: self.step_scheduler(lr_scheduler, test_metrics) self.test_metrics = test_metrics self.train_metrics = train_metrics - self._print_and_log_metrics( - epoch, MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), use_wandb + current_epoch_log = self._print_and_log_metrics( + epoch, + MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), + optimizer, + use_wandb=use_wandb, ) + epoch_pbar.update(1) + epoch_pbar.set_postfix_str(current_epoch_log.strip(', ')) + epoch_pbar.set_description(f"Epoch {epoch + 1}/{epochs}") + if early_stop and self._check_early_stop_condition(test_metrics): break + epoch_pbar.close() + batch_pbar.close() if use_wandb: wandb.log({"final epoch": epoch}) @@ -298,16 +336,16 @@ def _run_train_epoch( self, loader: DataLoader, loss_fn: Callable[[Tensor, Tensor], Tensor], - optimizer: t.optim.Optimizer + optimizer: t.optim.Optimizer, + pbar: tqdm ) -> MetricStoreCollection: self.ll_model.train() train_metrics = self.make_train_metrics() - for i, (base_input, ablation_input) in tqdm( - enumerate(loader), total=len(loader) - ): + for i, (base_input, ablation_input) in enumerate(loader): train_metrics.update( self.run_train_step(base_input, ablation_input, loss_fn, optimizer) ) + pbar.update(1) return train_metrics @final @@ -345,13 +383,22 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo def _print_and_log_metrics( epoch: int, metrics: MetricStoreCollection, - use_wandb: bool = False + optimizer: t.optim.Optimizer, + use_wandb: bool = False, + print_metrics: bool = True, ) -> None: - print(f"\nEpoch {epoch}:", end=" ") + + # Print the current epoch's metrics + current_epoch_log = f"lr: {optimizer.param_groups[0]['lr']:.2e}, " if use_wandb: wandb.log({"epoch": epoch}) + wandb.log({"lr": optimizer.param_groups[0]["lr"]}) + for metric in metrics: - print(metric, end=", ") + current_epoch_log += f"{metric}, " if use_wandb: wandb.log({metric.get_name(): metric.get_value()}) - print() + if print_metrics: + tqdm.write(f'Epoch {epoch+1}: {current_epoch_log.strip(", ")}') + + return current_epoch_log diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index 8cd720d..a619c11 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -23,8 +23,6 @@ def __init__( self.hl_model.requires_grad_(False) self.corr = corr - print(self.hl_model.hook_dict) - print(self.corr.keys()) assert all([str(k) in self.hl_model.hook_dict for k in self.corr.keys()]) default_training_args = { "batch_size": 256, @@ -34,6 +32,7 @@ def __init__( "lr_scheduler": None, "scheduler_val_metric": ["val/accuracy", "val/IIA"], "scheduler_mode": "max", + "scheduler_kwargs": {}, "clip_grad_norm": 1.0, "seed": 0, "detach_while_caching": True, diff --git a/iit/model_pairs/strict_iit_model_pair.py b/iit/model_pairs/strict_iit_model_pair.py index 9b7fc16..f928b40 100644 --- a/iit/model_pairs/strict_iit_model_pair.py +++ b/iit/model_pairs/strict_iit_model_pair.py @@ -54,6 +54,28 @@ def make_test_metrics() -> MetricStoreCollection: def sample_ll_node(self) -> LLNode: return self.rng.choice(np.array(self.nodes_not_in_circuit, dtype=object)) + + def get_SIIT_loss_over_batch( + self, + base_input: tuple[Tensor, Tensor, Tensor], + ablation_input: tuple[Tensor, Tensor, Tensor], + loss_fn: Callable[[Tensor, Tensor], Tensor] + ) -> Tensor: + base_x, base_y = base_input[0:2] + ablation_x, _ = ablation_input[0:2] + ll_node = self.sample_ll_node() + _, cache = self.ll_model.run_with_cache(ablation_x) + self.ll_cache = cache + out = self.ll_model.run_with_hooks( + base_x, fwd_hooks=[(ll_node.name, self.make_ll_ablation_hook(ll_node))] + ) + # print(out.shape, base_y.shape) + label_idx = self.get_label_idxs() + siit_loss = ( + loss_fn(out[label_idx.as_index], base_y[label_idx.as_index].to(self.ll_model.cfg.device)) + * self.training_args["strict_weight"] + ) # do this only for the tokens that we care about for IIT + return siit_loss def run_train_step( self, @@ -74,22 +96,26 @@ def run_train_step( # loss for nodes that are not in the circuit # should not have causal effect on the high-level output - base_x, base_y = base_input[0:2] - ablation_x, ablation_y = ablation_input[0:2] - ll_node = self.sample_ll_node() - _, cache = self.ll_model.run_with_cache(ablation_x) - self.ll_cache = cache - out = self.ll_model.run_with_hooks( - base_x, fwd_hooks=[(ll_node.name, self.make_ll_ablation_hook(ll_node))] - ) - # print(out.shape, base_y.shape) - label_idx = self.get_label_idxs() - ll_loss = ( - loss_fn(out[label_idx.as_index], base_y[label_idx.as_index].to(self.ll_model.cfg.device)) + siit_loss = ( + self.get_SIIT_loss_over_batch(base_input, ablation_input, loss_fn) * self.training_args["strict_weight"] - ) # do this only for the tokens that we care about for IIT + ) + # base_x, base_y = base_input[0:2] + # ablation_x, ablation_y = ablation_input[0:2] + # ll_node = self.sample_ll_node() + # _, cache = self.ll_model.run_with_cache(ablation_x) + # self.ll_cache = cache + # out = self.ll_model.run_with_hooks( + # base_x, fwd_hooks=[(ll_node.name, self.make_ll_ablation_hook(ll_node))] + # ) + # # print(out.shape, base_y.shape) + # label_idx = self.get_label_idxs() + # ll_loss = ( + # loss_fn(out[label_idx.as_index], base_y[label_idx.as_index].to(self.ll_model.cfg.device)) + # * self.training_args["strict_weight"] + # ) # do this only for the tokens that we care about for IIT if not use_single_loss: - self.step_on_loss(ll_loss, optimizer) + self.step_on_loss(siit_loss, optimizer) behavior_loss = ( self.get_behaviour_loss_over_batch(base_input, loss_fn) @@ -99,13 +125,13 @@ def run_train_step( self.step_on_loss(behavior_loss, optimizer) if use_single_loss: - total_loss = iit_loss + behavior_loss + ll_loss + total_loss = iit_loss + behavior_loss + siit_loss self.step_on_loss(total_loss, optimizer) return { "train/iit_loss": iit_loss.item(), "train/behavior_loss": behavior_loss.item(), - "train/strict_loss": ll_loss.item(), + "train/strict_loss": siit_loss.item(), } def run_eval_step( From 1c089cffa8dc463c1c1c8479a977838e92a8e2df Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Tue, 13 Aug 2024 17:21:18 -0700 Subject: [PATCH 2/9] A bunch of small changes to make training smoother; need to debug and type-check. --- iit/model_pairs/base_model_pair.py | 24 ++++++++++++++++------ iit/model_pairs/iit_behavior_model_pair.py | 10 +++++---- iit/model_pairs/iit_model_pair.py | 9 ++++++++ 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index 29bc497..c44acf7 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -197,7 +197,7 @@ def get_IIT_loss_over_batch( hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) label_idx = self.get_label_idxs() # IIT loss is only computed on the tokens we care about - loss = loss_fn(ll_output[label_idx.as_index], hl_output[label_idx.as_index]) + loss = loss_fn(ll_output[label_idx.as_index].to(hl_output.device), hl_output[label_idx.as_index]) return loss def clip_grad_fn(self) -> None: @@ -298,9 +298,9 @@ def train( self.test_metrics = test_metrics self.train_metrics = train_metrics current_epoch_log = self._print_and_log_metrics( - epoch, - MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), - optimizer, + epoch=epoch, + metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), + optimizer=optimizer, use_wandb=use_wandb, ) @@ -310,6 +310,8 @@ def train( if early_stop and self._check_early_stop_condition(test_metrics): break + + self._run_epoch_extras(epoch_number=epoch+1) epoch_pbar.close() batch_pbar.close() @@ -379,8 +381,8 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo return True @final - @staticmethod def _print_and_log_metrics( + self, epoch: int, metrics: MetricStoreCollection, optimizer: t.optim.Optimizer, @@ -390,15 +392,25 @@ def _print_and_log_metrics( # Print the current epoch's metrics current_epoch_log = f"lr: {optimizer.param_groups[0]['lr']:.2e}, " + for k in self.training_args.keys(): + if 'weight' in k and 'schedule' not in k: + current_epoch_log += f"{k}: {self.training_args[k]:.2e}, " if use_wandb: wandb.log({"epoch": epoch}) wandb.log({"lr": optimizer.param_groups[0]["lr"]}) for metric in metrics: - current_epoch_log += f"{metric}, " + if metric.type == MetricType.ACCURACY: + current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2f}, " + else: + current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2e}, " if use_wandb: wandb.log({metric.get_name(): metric.get_value()}) if print_metrics: tqdm.write(f'Epoch {epoch+1}: {current_epoch_log.strip(", ")}') return current_epoch_log + + @abstractmethod + def _run_epoch_extras(self, epoch_number: int) -> None: + pass diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index 4a74742..05b7605 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -56,8 +56,9 @@ def get_behaviour_loss_over_batch( ) -> Tensor: base_x, base_y = base_input[0:2] output = self.ll_model(base_x) - behavior_loss = loss_fn(output, base_y) - return behavior_loss + indx = self.get_label_idxs() + behaviour_loss = loss_fn(output[indx.as_index], base_y[indx.as_index].to(output.device)) + return behaviour_loss def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None: optimizer.zero_grad() @@ -110,7 +111,7 @@ def run_eval_step( label_idx = self.get_label_idxs() hl_node = self.sample_hl_name() hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) - hl_output.to(ll_output.device) + hl_output = hl_output.to(ll_output.device) hl_output = hl_output[label_idx.as_index] ll_output = ll_output[label_idx.as_index] if self.hl_model.is_categorical(): @@ -130,10 +131,11 @@ def run_eval_step( output = self.ll_model(base_x) if self.hl_model.is_categorical(): top1 = t.argmax(output, dim=-1) - if output.shape == base_y.shape: + if output.shape[-1] == base_y.shape[-1]: # To handle the case when labels are one-hot # TODO: is there a better way? base_y = t.argmax(base_y, dim=-1) + base_y = base_y.to(top1.device) accuracy = (top1 == base_y).float().mean() else: accuracy = ((output - base_y).abs() < atol).float().mean() diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index a619c11..535ab92 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -36,6 +36,9 @@ def __init__( "clip_grad_norm": 1.0, "seed": 0, "detach_while_caching": True, + "iit_weight_schedule" : lambda s, i: s, + "strict_weight_schedule" : lambda s, i: s, + "behavior_weight_schedule" : lambda s, i: s, } training_args = {**default_training_args, **training_args} if isinstance(ll_model, HookedRootModule): @@ -124,3 +127,9 @@ def run_train_step( loss.backward() # type: ignore optimizer.step() return {"train/iit_loss": loss.item()} + + + def _run_epoch_extras(self, epoch_number: int) -> None: + self.training_args['iit_weight'] = self.training_args['iit_weight_schedule'](self.training_args['iit_weight'], epoch_number) + self.training_args['strict_weight'] = self.training_args['strict_weight_schedule'](self.training_args['strict_weight'], epoch_number) + self.training_args['behavior_weight'] = self.training_args['behavior_weight_schedule'](self.training_args['behavior_weight'], epoch_number) \ No newline at end of file From 97c18d3a5fa30039f46fdbc15f9790186a03b7b6 Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Tue, 20 Aug 2024 17:50:32 -0700 Subject: [PATCH 3/9] New SIIT sampling protocol, Adds optimizer_kwargs to training_args --- iit/model_pairs/base_model_pair.py | 59 +++++++++++----------- iit/model_pairs/iit_behavior_model_pair.py | 41 ++++++++------- iit/model_pairs/iit_model_pair.py | 8 +-- iit/model_pairs/strict_iit_model_pair.py | 21 ++++++-- 4 files changed, 73 insertions(+), 56 deletions(-) diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index c44acf7..c47c155 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -242,7 +242,6 @@ def train( epochs: int = 1000, use_wandb: bool = False, wandb_name_suffix: str = "", - optimizer_kwargs: dict = {}, ) -> None: training_args = self.training_args print(f"{training_args=}") @@ -262,8 +261,7 @@ def train( early_stop = training_args["early_stop"] - optimizer_kwargs['lr'] = training_args["lr"] - optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs) + optimizer = optimizer_cls(self.ll_model.parameters(), **training_args['optimizer_kwargs']) loss_fn = self.loss_fn scheduler_cls = training_args.get("lr_scheduler", None) scheduler_kwargs = training_args.get("scheduler_kwargs", {}) @@ -285,35 +283,36 @@ def train( wandb.config.update(training_args) wandb.config.update({"method": self.wandb_method}) wandb.run.log_code() # type: ignore - - epoch_pbar = tqdm(range(epochs), desc="Training Epochs") - batch_pbar = tqdm(total=len(train_loader), desc="Training Batches") - for epoch in range(epochs): - batch_pbar.reset() - - train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer, batch_pbar) - test_metrics = self._run_eval_epoch(test_loader, loss_fn) - if scheduler_cls: - self.step_scheduler(lr_scheduler, test_metrics) - self.test_metrics = test_metrics - self.train_metrics = train_metrics - current_epoch_log = self._print_and_log_metrics( - epoch=epoch, - metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), - optimizer=optimizer, - use_wandb=use_wandb, - ) - epoch_pbar.update(1) - epoch_pbar.set_postfix_str(current_epoch_log.strip(', ')) - epoch_pbar.set_description(f"Epoch {epoch + 1}/{epochs}") + + # Set seed before iterating on loaders for reproduceablility. + t.manual_seed(training_args["seed"]) + with tqdm(range(epochs), desc="Training Epochs") as epoch_pbar: + with tqdm(total=len(train_loader), desc="Training Batches") as batch_pbar: + for epoch in range(epochs): + batch_pbar.reset() + + train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer, batch_pbar) + test_metrics = self._run_eval_epoch(test_loader, loss_fn) + if scheduler_cls: + self.step_scheduler(lr_scheduler, test_metrics) + self.test_metrics = test_metrics + self.train_metrics = train_metrics + current_epoch_log = self._print_and_log_metrics( + epoch=epoch, + metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), + optimizer=optimizer, + use_wandb=use_wandb, + ) + + epoch_pbar.update(1) + epoch_pbar.set_postfix_str(current_epoch_log.strip(', ')) + epoch_pbar.set_description(f"Epoch {epoch + 1}/{epochs}") - if early_stop and self._check_early_stop_condition(test_metrics): - break - - self._run_epoch_extras(epoch_number=epoch+1) - epoch_pbar.close() - batch_pbar.close() + if early_stop and self._check_early_stop_condition(test_metrics): + break + + self._run_epoch_extras(epoch_number=epoch+1) if use_wandb: wandb.log({"final epoch": epoch}) diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index 05b7605..3a8a776 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -109,32 +109,37 @@ def run_eval_step( # compute IIT loss and accuracy label_idx = self.get_label_idxs() - hl_node = self.sample_hl_name() - hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) - hl_output = hl_output.to(ll_output.device) - hl_output = hl_output[label_idx.as_index] - ll_output = ll_output[label_idx.as_index] - if self.hl_model.is_categorical(): - loss = loss_fn(ll_output, hl_output) - if ll_output.shape == hl_output.shape: - # To handle the case when labels are one-hot - hl_output = t.argmax(hl_output, dim=-1) - top1 = t.argmax(ll_output, dim=-1) - accuracy = (top1 == hl_output).float().mean() - IIA = accuracy.item() - else: - loss = loss_fn(ll_output, hl_output) - IIA = ((ll_output - hl_output).abs() < atol).float().mean().item() + #don't just sample one HL node, compute IIA for all HL nodes and average. + iias = [] + for hl_node in self.corr.keys(): + # hl_node = self.sample_hl_name() + hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) + hl_output = hl_output.to(ll_output.device) + hl_output = hl_output[label_idx.as_index] + ll_output = ll_output[label_idx.as_index] + if self.hl_model.is_categorical(): + loss = loss_fn(ll_output, hl_output) + if ll_output.shape == hl_output.shape: + # To handle the case when labels are one-hot + hl_output = t.argmax(hl_output, dim=-1) + top1 = t.argmax(ll_output, dim=-1) + accuracy = (top1 == hl_output).float().mean() + IIA = accuracy.item() + else: + loss = loss_fn(ll_output, hl_output) + IIA = ((ll_output - hl_output).abs() < atol).float().mean().item() + iias.append(IIA) + IIA = sum(iias) / len(iias) # compute behavioral accuracy base_x, base_y = base_input[0:2] - output = self.ll_model(base_x) + output = self.ll_model(base_x)[label_idx.as_index] #convert ll logits -> one-hot max label if self.hl_model.is_categorical(): top1 = t.argmax(output, dim=-1) if output.shape[-1] == base_y.shape[-1]: # To handle the case when labels are one-hot # TODO: is there a better way? - base_y = t.argmax(base_y, dim=-1) + base_y = t.argmax(base_y, dim=-1).squeeze() base_y = base_y.to(top1.device) accuracy = (top1 == base_y).float().mean() else: diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index 535ab92..e1920c0 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -26,21 +26,21 @@ def __init__( assert all([str(k) in self.hl_model.hook_dict for k in self.corr.keys()]) default_training_args = { "batch_size": 256, - "lr": 0.001, "num_workers": 0, "early_stop": True, "lr_scheduler": None, "scheduler_val_metric": ["val/accuracy", "val/IIA"], "scheduler_mode": "max", "scheduler_kwargs": {}, + "optimizer_kwargs" : { + "lr": 0.001, + }, "clip_grad_norm": 1.0, "seed": 0, "detach_while_caching": True, - "iit_weight_schedule" : lambda s, i: s, - "strict_weight_schedule" : lambda s, i: s, - "behavior_weight_schedule" : lambda s, i: s, } training_args = {**default_training_args, **training_args} + if isinstance(ll_model, HookedRootModule): ll_model = LLModel.make_from_hooked_transformer( ll_model, detach_while_caching=training_args["detach_while_caching"] diff --git a/iit/model_pairs/strict_iit_model_pair.py b/iit/model_pairs/strict_iit_model_pair.py index f928b40..598e913 100644 --- a/iit/model_pairs/strict_iit_model_pair.py +++ b/iit/model_pairs/strict_iit_model_pair.py @@ -29,6 +29,7 @@ def __init__( "behavior_weight": 1.0, "strict_weight": 1.0, "clip_grad_norm": 1.0, + "siit_sampling" : "individual", # individual, sample_all, all } training_args = {**default_training_args, **training_args} super().__init__(hl_model, ll_model, corr=corr, training_args=training_args) @@ -52,8 +53,17 @@ def make_test_metrics() -> MetricStoreCollection: IITBehaviorModelPair.make_test_metrics().metrics + [MetricStore("val/strict_accuracy", MetricType.ACCURACY)], ) - def sample_ll_node(self) -> LLNode: - return self.rng.choice(np.array(self.nodes_not_in_circuit, dtype=object)) + def sample_ll_nodes(self) -> list[LLNode]: + if self.training_args['siit_sampling'] == 'individual': + ll_nodes = [self.rng.choice(np.array(self.nodes_not_in_circuit, dtype=object)),] + elif self.training_args['siit_sampling'] == 'sample_all': + importance = t.randint(0, 2, (len(self.nodes_not_in_circuit),)).to(bool).tolist() + ll_nodes = [node for node, imp in zip(self.nodes_not_in_circuit, importance) if imp] + elif self.training_args['siit_sampling'] == 'all': + ll_nodes = self.nodes_not_in_circuit + else: + raise ValueError(f"Unexpected SIIT sampling mode: {self.training_args['siit_sampling']}") + return ll_nodes def get_SIIT_loss_over_batch( self, @@ -63,11 +73,14 @@ def get_SIIT_loss_over_batch( ) -> Tensor: base_x, base_y = base_input[0:2] ablation_x, _ = ablation_input[0:2] - ll_node = self.sample_ll_node() + ll_nodes = self.sample_ll_nodes() _, cache = self.ll_model.run_with_cache(ablation_x) self.ll_cache = cache + hooks = [] + for ll_node in ll_nodes: + hooks.append((ll_node.name, self.make_ll_ablation_hook(ll_node))) out = self.ll_model.run_with_hooks( - base_x, fwd_hooks=[(ll_node.name, self.make_ll_ablation_hook(ll_node))] + base_x, fwd_hooks=hooks ) # print(out.shape, base_y.shape) label_idx = self.get_label_idxs() From 5e5eaa05b45d4690cd20b6ec535a1ea94af58142 Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Thu, 22 Aug 2024 14:53:38 -0700 Subject: [PATCH 4/9] more training fixups (args, etc) --- iit/model_pairs/base_model_pair.py | 2 +- iit/model_pairs/iit_behavior_model_pair.py | 10 +++++++--- iit/model_pairs/iit_model_pair.py | 6 ------ iit/model_pairs/strict_iit_model_pair.py | 15 +++++++-------- 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index c47c155..07cc6ad 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -410,6 +410,6 @@ def _print_and_log_metrics( return current_epoch_log - @abstractmethod def _run_epoch_extras(self, epoch_number: int) -> None: + """ Optional method for running extra code at the end of each epoch """ pass diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index 3a8a776..ea0a6b7 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -19,12 +19,12 @@ def __init__( training_args: dict = {} ): default_training_args = { - "lr": 0.001, "atol": 5e-2, - "early_stop": True, "use_single_loss": False, "iit_weight": 1.0, "behavior_weight": 1.0, + "iit_weight_schedule" : lambda s, i: s, + "behavior_weight_schedule" : lambda s, i: s, } training_args = {**default_training_args, **training_args} super().__init__(hl_model, ll_model, corr=corr, training_args=training_args) @@ -158,4 +158,8 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo return metric.get_value() == 100 else: return super()._check_early_stop_condition(test_metrics) - return False \ No newline at end of file + return False + + def _run_epoch_extras(self, epoch_number: int) -> None: + self.training_args['iit_weight'] = self.training_args['iit_weight_schedule'](self.training_args['iit_weight'], epoch_number) + self.training_args['behavior_weight'] = self.training_args['behavior_weight_schedule'](self.training_args['behavior_weight'], epoch_number) \ No newline at end of file diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index e1920c0..d33bfe9 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -127,9 +127,3 @@ def run_train_step( loss.backward() # type: ignore optimizer.step() return {"train/iit_loss": loss.item()} - - - def _run_epoch_extras(self, epoch_number: int) -> None: - self.training_args['iit_weight'] = self.training_args['iit_weight_schedule'](self.training_args['iit_weight'], epoch_number) - self.training_args['strict_weight'] = self.training_args['strict_weight_schedule'](self.training_args['strict_weight'], epoch_number) - self.training_args['behavior_weight'] = self.training_args['behavior_weight_schedule'](self.training_args['behavior_weight'], epoch_number) \ No newline at end of file diff --git a/iit/model_pairs/strict_iit_model_pair.py b/iit/model_pairs/strict_iit_model_pair.py index 598e913..36a4721 100644 --- a/iit/model_pairs/strict_iit_model_pair.py +++ b/iit/model_pairs/strict_iit_model_pair.py @@ -21,14 +21,8 @@ def __init__( training_args: dict = {} ): default_training_args = { - "batch_size": 256, - "lr": 0.001, - "num_workers": 0, - "use_single_loss": False, - "iit_weight": 1.0, - "behavior_weight": 1.0, "strict_weight": 1.0, - "clip_grad_norm": 1.0, + "strict_weight_schedule" : lambda s, i: s, "siit_sampling" : "individual", # individual, sample_all, all } training_args = {**default_training_args, **training_args} @@ -194,4 +188,9 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo metrics_to_check.append(metric) if metric.get_name() == "val/IIA" and self.training_args["iit_weight"] > 0: metrics_to_check.append(metric) - return super()._check_early_stop_condition(MetricStoreCollection(metrics_to_check)) \ No newline at end of file + return super()._check_early_stop_condition(MetricStoreCollection(metrics_to_check)) + + + def _run_epoch_extras(self, epoch_number: int) -> None: + super()._run_epoch_extras(epoch_number) + self.training_args['strict_weight'] = self.training_args['strict_weight_schedule'](self.training_args['strict_weight'], epoch_number) \ No newline at end of file From 7734db566a529c61d91e64f76494fb98de3a60ad Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Thu, 22 Aug 2024 15:22:57 -0700 Subject: [PATCH 5/9] fixes typing errors --- iit/model_pairs/base_model_pair.py | 6 +++--- iit/model_pairs/strict_iit_model_pair.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index 07cc6ad..cdf56c4 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -20,10 +20,10 @@ from iit.utils.index import Ix, TorchIndex from iit.utils.metric import MetricStoreCollection, MetricType -def in_notebook(): +def in_notebook() -> bool: try: # This will only work in Jupyter notebooks - shell = get_ipython().__class__.__name__ + shell = get_ipython().__class__.__name__ # type: ignore if shell == 'ZMQInteractiveShell': return True # Jupyter notebook or qtconsole elif shell == 'TerminalInteractiveShell': @@ -387,7 +387,7 @@ def _print_and_log_metrics( optimizer: t.optim.Optimizer, use_wandb: bool = False, print_metrics: bool = True, - ) -> None: + ) -> str: # Print the current epoch's metrics current_epoch_log = f"lr: {optimizer.param_groups[0]['lr']:.2e}, " diff --git a/iit/model_pairs/strict_iit_model_pair.py b/iit/model_pairs/strict_iit_model_pair.py index 80b17d4..669a0e8 100644 --- a/iit/model_pairs/strict_iit_model_pair.py +++ b/iit/model_pairs/strict_iit_model_pair.py @@ -55,7 +55,7 @@ def sample_ll_nodes(self) -> list[LLNode]: if self.training_args['siit_sampling'] == 'individual': ll_nodes = [self.rng.choice(np.array(self.nodes_not_in_circuit, dtype=object)),] elif self.training_args['siit_sampling'] == 'sample_all': - importance = t.randint(0, 2, (len(self.nodes_not_in_circuit),)).to(bool).tolist() + importance = t.randint(0, 2, (len(self.nodes_not_in_circuit),)).to(t.bool).tolist() ll_nodes = [node for node, imp in zip(self.nodes_not_in_circuit, importance) if imp] elif self.training_args['siit_sampling'] == 'all': ll_nodes = self.nodes_not_in_circuit @@ -97,9 +97,9 @@ def run_train_step( ) -> dict: use_single_loss = self.training_args["use_single_loss"] - iit_loss = 0 - ll_loss = 0 - behavior_loss = 0 + iit_loss = t.zeros(1) + siit_loss = t.zeros(1) + behavior_loss = t.zeros(1) if self.training_args["iit_weight"] > 0: hl_node = self.sample_hl_name() # sample a high-level variable to ablate @@ -118,7 +118,7 @@ def run_train_step( * self.training_args["strict_weight"] ) if not use_single_loss: - self.step_on_loss(ll_loss, optimizer) + self.step_on_loss(siit_loss, optimizer) if self.training_args["behavior_weight"] > 0: behavior_loss = ( From 18e467e45ecf28ff3b8b33caf0f5416614d9b614 Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Thu, 22 Aug 2024 15:28:14 -0700 Subject: [PATCH 6/9] Defaults: use_single_step = True, betas = (0.9, 0.9); adds optimzier_cls to training_args --- iit/model_pairs/base_model_pair.py | 3 +-- iit/model_pairs/iit_behavior_model_pair.py | 2 +- iit/model_pairs/iit_model_pair.py | 8 +++++--- iit/model_pairs/probed_sequential_pair.py | 5 ++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index cdf56c4..56fa0f8 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -238,7 +238,6 @@ def train( self, train_set: IITDataset, test_set: IITDataset, - optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam, epochs: int = 1000, use_wandb: bool = False, wandb_name_suffix: str = "", @@ -261,7 +260,7 @@ def train( early_stop = training_args["early_stop"] - optimizer = optimizer_cls(self.ll_model.parameters(), **training_args['optimizer_kwargs']) + optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **training_args['optimizer_kwargs']) loss_fn = self.loss_fn scheduler_cls = training_args.get("lr_scheduler", None) scheduler_kwargs = training_args.get("scheduler_kwargs", {}) diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index ea0a6b7..1a40abc 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -20,7 +20,7 @@ def __init__( ): default_training_args = { "atol": 5e-2, - "use_single_loss": False, + "use_single_loss": True, "iit_weight": 1.0, "behavior_weight": 1.0, "iit_weight_schedule" : lambda s, i: s, diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index d33bfe9..0e5ccec 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -32,12 +32,14 @@ def __init__( "scheduler_val_metric": ["val/accuracy", "val/IIA"], "scheduler_mode": "max", "scheduler_kwargs": {}, - "optimizer_kwargs" : { - "lr": 0.001, - }, "clip_grad_norm": 1.0, "seed": 0, "detach_while_caching": True, + "optimizer_cls": t.optim.Adam, + "optimizer_kwargs" : { + "lr": 0.001, + "betas": (0.9, 0.9) + }, } training_args = {**default_training_args, **training_args} diff --git a/iit/model_pairs/probed_sequential_pair.py b/iit/model_pairs/probed_sequential_pair.py index 88c4bf1..255ad63 100644 --- a/iit/model_pairs/probed_sequential_pair.py +++ b/iit/model_pairs/probed_sequential_pair.py @@ -82,7 +82,6 @@ def train( self, train_set: IITDataset, test_set: IITDataset, - optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam, epochs: int = 1000, use_wandb: bool = False, wandb_name_suffix: str = "", @@ -107,9 +106,9 @@ def train( for p in probes.values(): params += list(p.parameters()) optimizer_kwargs['lr'] = training_args["lr"] - probe_optimizer = optimizer_cls(params, **optimizer_kwargs) + probe_optimizer = training_args['optimizer_cls'](params, **optimizer_kwargs) - optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs) + optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **optimizer_kwargs) loss_fn = t.nn.CrossEntropyLoss() if use_wandb and not wandb.run: From 0d1b24f3825a49c897e8a18ac2e50d57e1ac191f Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Fri, 23 Aug 2024 09:09:00 -0700 Subject: [PATCH 7/9] cleans up PR per feedback --- iit/model_pairs/base_model_pair.py | 43 ++++++++-------------- iit/model_pairs/iit_behavior_model_pair.py | 41 ++++++++++++--------- iit/model_pairs/iit_model_pair.py | 1 - iit/model_pairs/strict_iit_model_pair.py | 14 ++----- iit/utils/tqdm.py | 17 +++++++++ 5 files changed, 60 insertions(+), 56 deletions(-) create mode 100644 iit/utils/tqdm.py diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py index 56fa0f8..1c3cceb 100644 --- a/iit/model_pairs/base_model_pair.py +++ b/iit/model_pairs/base_model_pair.py @@ -9,7 +9,6 @@ from torch.utils.data import DataLoader from tqdm import tqdm # type: ignore from transformer_lens.hook_points import HookedRootModule, HookPoint # type: ignore -from IPython.display import clear_output import wandb # type: ignore from iit.model_pairs.ll_model import LLModel @@ -19,24 +18,8 @@ from iit.utils.iit_dataset import IITDataset from iit.utils.index import Ix, TorchIndex from iit.utils.metric import MetricStoreCollection, MetricType +from iit.utils.tqdm import tqdm -def in_notebook() -> bool: - try: - # This will only work in Jupyter notebooks - shell = get_ipython().__class__.__name__ # type: ignore - if shell == 'ZMQInteractiveShell': - return True # Jupyter notebook or qtconsole - elif shell == 'TerminalInteractiveShell': - return False # Terminal running IPython - else: - return False # Other types of interactive shells - except NameError: - return False # Probably standard Python interpreter - -if in_notebook(): - from tqdm.notebook import tqdm -else: - from tqdm import tqdm class BaseModelPair(ABC): @@ -197,7 +180,7 @@ def get_IIT_loss_over_batch( hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) label_idx = self.get_label_idxs() # IIT loss is only computed on the tokens we care about - loss = loss_fn(ll_output[label_idx.as_index].to(hl_output.device), hl_output[label_idx.as_index]) + loss = loss_fn(ll_output[label_idx.as_index], hl_output[label_idx.as_index]) return loss def clip_grad_fn(self) -> None: @@ -251,6 +234,10 @@ def train( assert isinstance(test_set, IITDataset), ValueError( f"test_set is not an instance of IITDataset, but {type(test_set)}" ) + assert self.ll_model.device == self.hl_model.device, ValueError( + "ll_model and hl_model are not on the same device" + ) + train_loader, test_loader = self.make_loaders( train_set, test_set, @@ -297,17 +284,14 @@ def train( self.step_scheduler(lr_scheduler, test_metrics) self.test_metrics = test_metrics self.train_metrics = train_metrics - current_epoch_log = self._print_and_log_metrics( + self._print_and_log_metrics( epoch=epoch, metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), optimizer=optimizer, use_wandb=use_wandb, + epoch_pbar=epoch_pbar ) - epoch_pbar.update(1) - epoch_pbar.set_postfix_str(current_epoch_log.strip(', ')) - epoch_pbar.set_description(f"Epoch {epoch + 1}/{epochs}") - if early_stop and self._check_early_stop_condition(test_metrics): break @@ -386,6 +370,7 @@ def _print_and_log_metrics( optimizer: t.optim.Optimizer, use_wandb: bool = False, print_metrics: bool = True, + epoch_pbar: Optional[tqdm] = None, ) -> str: # Print the current epoch's metrics @@ -398,15 +383,17 @@ def _print_and_log_metrics( wandb.log({"lr": optimizer.param_groups[0]["lr"]}) for metric in metrics: - if metric.type == MetricType.ACCURACY: - current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2f}, " - else: - current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2e}, " + current_epoch_log += str(metric) + ", " if use_wandb: wandb.log({metric.get_name(): metric.get_value()}) if print_metrics: tqdm.write(f'Epoch {epoch+1}: {current_epoch_log.strip(", ")}') + if epoch_pbar is not None: + epoch_pbar.update(1) + epoch_pbar.set_postfix_str(current_epoch_log.strip(', ')) + epoch_pbar.set_description(f"Epoch {epoch + 1}") + return current_epoch_log def _run_epoch_extras(self, epoch_number: int) -> None: diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index 1a40abc..fa927a0 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -8,6 +8,7 @@ from iit.model_pairs.ll_model import LLModel from iit.utils.correspondence import Correspondence from iit.utils.metric import MetricStore, MetricStoreCollection, MetricType +from iit.utils.nodes import HLNode class IITBehaviorModelPair(IITModelPair): @@ -20,11 +21,10 @@ def __init__( ): default_training_args = { "atol": 5e-2, - "use_single_loss": True, + "use_single_loss": False, "iit_weight": 1.0, "behavior_weight": 1.0, - "iit_weight_schedule" : lambda s, i: s, - "behavior_weight_schedule" : lambda s, i: s, + "val_IIA_sampling": "random", # random or all } training_args = {**default_training_args, **training_args} super().__init__(hl_model, ll_model, corr=corr, training_args=training_args) @@ -56,8 +56,8 @@ def get_behaviour_loss_over_batch( ) -> Tensor: base_x, base_y = base_input[0:2] output = self.ll_model(base_x) - indx = self.get_label_idxs() - behaviour_loss = loss_fn(output[indx.as_index], base_y[indx.as_index].to(output.device)) + label_indx = self.get_label_idxs() + behaviour_loss = loss_fn(output[label_indx.as_index], base_y[label_indx.as_index].to(output.device)) return behaviour_loss def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None: @@ -109,10 +109,8 @@ def run_eval_step( # compute IIT loss and accuracy label_idx = self.get_label_idxs() - #don't just sample one HL node, compute IIA for all HL nodes and average. - iias = [] - for hl_node in self.corr.keys(): - # hl_node = self.sample_hl_name() + + def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]: hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node) hl_output = hl_output.to(ll_output.device) hl_output = hl_output[label_idx.as_index] @@ -128,8 +126,22 @@ def run_eval_step( else: loss = loss_fn(ll_output, hl_output) IIA = ((ll_output - hl_output).abs() < atol).float().mean().item() - iias.append(IIA) - IIA = sum(iias) / len(iias) + return IIA, loss + + if self.training_args["val_IIA_sampling"] == "random": + hl_node = self.sample_hl_name() + IIA, loss = get_node_IIT_info(hl_node) + elif self.training_args["val_IIA_sampling"] == "all": + iias = [] + losses = [] + for hl_node in self.corr.keys(): + IIA, loss = get_node_IIT_info(hl_node) + iias.append(IIA) + losses.append(loss) + IIA = sum(iias) / len(iias) + loss = t.cat(losses).mean() + else: + raise ValueError(f"Invalid val_IIA_sampling: {self.training_args['val_IIA_sampling']}") # compute behavioral accuracy base_x, base_y = base_input[0:2] @@ -140,7 +152,6 @@ def run_eval_step( # To handle the case when labels are one-hot # TODO: is there a better way? base_y = t.argmax(base_y, dim=-1).squeeze() - base_y = base_y.to(top1.device) accuracy = (top1 == base_y).float().mean() else: accuracy = ((output - base_y).abs() < atol).float().mean() @@ -158,8 +169,4 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo return metric.get_value() == 100 else: return super()._check_early_stop_condition(test_metrics) - return False - - def _run_epoch_extras(self, epoch_number: int) -> None: - self.training_args['iit_weight'] = self.training_args['iit_weight_schedule'](self.training_args['iit_weight'], epoch_number) - self.training_args['behavior_weight'] = self.training_args['behavior_weight_schedule'](self.training_args['behavior_weight'], epoch_number) \ No newline at end of file + return False \ No newline at end of file diff --git a/iit/model_pairs/iit_model_pair.py b/iit/model_pairs/iit_model_pair.py index 0e5ccec..e866932 100644 --- a/iit/model_pairs/iit_model_pair.py +++ b/iit/model_pairs/iit_model_pair.py @@ -38,7 +38,6 @@ def __init__( "optimizer_cls": t.optim.Adam, "optimizer_kwargs" : { "lr": 0.001, - "betas": (0.9, 0.9) }, } training_args = {**default_training_args, **training_args} diff --git a/iit/model_pairs/strict_iit_model_pair.py b/iit/model_pairs/strict_iit_model_pair.py index 669a0e8..bb055bd 100644 --- a/iit/model_pairs/strict_iit_model_pair.py +++ b/iit/model_pairs/strict_iit_model_pair.py @@ -22,7 +22,6 @@ def __init__( ): default_training_args = { "strict_weight": 1.0, - "strict_weight_schedule" : lambda s, i: s, "siit_sampling" : "individual", # individual, sample_all, all } training_args = {**default_training_args, **training_args} @@ -133,9 +132,9 @@ def run_train_step( self.step_on_loss(total_loss, optimizer) return { - "train/iit_loss": iit_loss.item() if isinstance(iit_loss, Tensor) else iit_loss, - "train/behavior_loss": behavior_loss.item() if isinstance(behavior_loss, Tensor) else behavior_loss, - "train/strict_loss": siit_loss.item() if isinstance(siit_loss, Tensor) else siit_loss, + "train/iit_loss": iit_loss.item(), + "train/behavior_loss": behavior_loss.item(), + "train/strict_loss": siit_loss.item(), } def run_eval_step( @@ -190,9 +189,4 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo metrics_to_check.append(metric) if metric.get_name() == "val/IIA" and self.training_args["iit_weight"] > 0: metrics_to_check.append(metric) - return super()._check_early_stop_condition(MetricStoreCollection(metrics_to_check)) - - - def _run_epoch_extras(self, epoch_number: int) -> None: - super()._run_epoch_extras(epoch_number) - self.training_args['strict_weight'] = self.training_args['strict_weight_schedule'](self.training_args['strict_weight'], epoch_number) \ No newline at end of file + return super()._check_early_stop_condition(MetricStoreCollection(metrics_to_check)) \ No newline at end of file diff --git a/iit/utils/tqdm.py b/iit/utils/tqdm.py new file mode 100644 index 0000000..69e5f86 --- /dev/null +++ b/iit/utils/tqdm.py @@ -0,0 +1,17 @@ +def in_notebook() -> bool: + try: + # This will only work in Jupyter notebooks + shell = get_ipython().__class__.__name__ # type: ignore + if shell == 'ZMQInteractiveShell': + return True # Jupyter notebook or qtconsole + elif shell == 'TerminalInteractiveShell': + return False # Terminal running IPython + else: + return False # Other types of interactive shells + except NameError: + return False # Probably standard Python interpreter + +if in_notebook(): + from tqdm.notebook import tqdm +else: + from tqdm import tqdm \ No newline at end of file From 313309e3e7c859a68093a7f034414fa9b0a6f145 Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Fri, 23 Aug 2024 09:25:43 -0700 Subject: [PATCH 8/9] reverts some indexing in IITBehaviorModelPair that caused circuits-bench to break --- iit/model_pairs/iit_behavior_model_pair.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index fa927a0..380815b 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -145,13 +145,13 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]: # compute behavioral accuracy base_x, base_y = base_input[0:2] - output = self.ll_model(base_x)[label_idx.as_index] #convert ll logits -> one-hot max label + output = self.ll_model(base_x) if self.hl_model.is_categorical(): top1 = t.argmax(output, dim=-1) - if output.shape[-1] == base_y.shape[-1]: + if output.shape == base_y.shape: # To handle the case when labels are one-hot # TODO: is there a better way? - base_y = t.argmax(base_y, dim=-1).squeeze() + base_y = t.argmax(base_y, dim=-1) accuracy = (top1 == base_y).float().mean() else: accuracy = ((output - base_y).abs() < atol).float().mean() From e4eba0adb2bb0aec3cab36896235a56c307ffd95 Mon Sep 17 00:00:00 2001 From: Evan Anders Date: Fri, 23 Aug 2024 09:42:09 -0700 Subject: [PATCH 9/9] bug fix and adds back in one .to(device) --- iit/model_pairs/iit_behavior_model_pair.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/iit/model_pairs/iit_behavior_model_pair.py b/iit/model_pairs/iit_behavior_model_pair.py index 380815b..c871466 100644 --- a/iit/model_pairs/iit_behavior_model_pair.py +++ b/iit/model_pairs/iit_behavior_model_pair.py @@ -139,12 +139,13 @@ def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]: iias.append(IIA) losses.append(loss) IIA = sum(iias) / len(iias) - loss = t.cat(losses).mean() + loss = t.stack(losses).mean() else: raise ValueError(f"Invalid val_IIA_sampling: {self.training_args['val_IIA_sampling']}") # compute behavioral accuracy base_x, base_y = base_input[0:2] + base_y = base_y.to(self.ll_model.device) # so that input data doesn't all need to be hogging room on device. output = self.ll_model(base_x) if self.hl_model.is_categorical(): top1 = t.argmax(output, dim=-1)