Skip to content

Commit

Permalink
remove redundancies and NAMES dependency for IOI
Browse files (browse the repository at this point in the history)
  • Loading branch information
cybershiptrooper committed Sep 7, 2024
1 parent 5214f10 commit cd832b8
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion iit/model_pairs/base_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo
if metric.type == MetricType.ACCURACY:
got_accuracy_metric = True
val = metric.get_value()
if isinstance(val, float) and val < 100:
if isinstance(val, float) and val < 99.5:
return False
if not got_accuracy_metric:
raise ValueError("No accuracy metric found in test_metrics!")
Expand Down
2 changes: 1 addition & 1 deletion iit/model_pairs/iit_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
"detach_while_caching": True,
"optimizer_cls": t.optim.Adam,
"optimizer_kwargs" : {
# "betas": (0.9, 0.9)
"betas": (0.9, 0.9)
},
}
training_args = {**default_training_args, **training_args}
Expand Down
22 changes: 13 additions & 9 deletions iit/tasks/ioi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@
}


def make_corr_dict(include_mlp: bool = False, eval: bool = False, use_pos_embed: bool = False) -> dict:
def make_corr_dict(
include_mlp: bool = False, eval: bool = False, use_pos_embed: bool = False
) -> dict:
all_attns = [f"blocks.{i}.attn.hook_z" for i in range(ioi_cfg["n_layers"])]
all_mlps = [f"blocks.{i}.mlp.hook_post" for i in range(ioi_cfg["n_layers"])]
attn_idx = Ix[:, :, 1]
if eval:
all_nodes_hook = "blocks.0.hook_resid_pre" if not use_pos_embed else "blocks.0.hook_pos_embed"
all_nodes_hook = (
"blocks.0.hook_resid_pre" if not use_pos_embed else "blocks.0.hook_pos_embed"
)
return {
"hook_duplicate": [[all_attns[1], Ix[[None]], None]],
"hook_duplicate": [[all_attns[1], attn_idx, None]],
# "hook_previous": ["blocks.1.attn.hook_result"],
"hook_s_inhibition": [[all_attns[2], Ix[[None]], None]],
"hook_name_mover": [[all_attns[4], Ix[[None]], None]],
"hook_s_inhibition": [[all_attns[2], attn_idx, None]],
"hook_name_mover": [[all_attns[4], attn_idx, None]],
"all_nodes_hook": (
[[all_nodes_hook, Ix[[None]], None], [all_mlps[0], Ix[[None]], None]]
if include_mlp
Expand All @@ -36,15 +41,14 @@ def make_corr_dict(include_mlp: bool = False, eval: bool = False, use_pos_embed:
"hook_out": [[f"blocks.{n_layers-1}.hook_resid_post", Ix[[None]], None]],
}
ans = {
"hook_duplicate": [[all_attns[1], Ix[[None]], None]],
"hook_duplicate": [[all_attns[1], attn_idx, None]],
# "hook_previous": ["blocks.1.attn.hook_result"],
"hook_s_inhibition": [[all_attns[2], Ix[[None]], None]],
"hook_name_mover": [[all_attns[4], Ix[[None]], None]],
"hook_s_inhibition": [[all_attns[2], attn_idx, None]],
"hook_name_mover": [[all_attns[4], attn_idx, None]],
}
if include_mlp:
ans["all_nodes_hook"] = [[all_mlps[0], Ix[[None]], None]]
return ans



corr_dict = make_corr_dict(include_mlp=False)
Expand Down
3 changes: 1 addition & 2 deletions iit/utils/eval_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from iit.model_pairs.ioi_model_pair import IOI_ModelPair
from iit.tasks.ioi import (
make_ioi_dataset_and_hl,
NAMES,
ioi_cfg,
make_corr_dict,
suffixes,
Expand Down Expand Up @@ -72,7 +71,7 @@ def eval_ioi(args: IOIArgParseNamespace) -> None:
np.random.seed(0)
t.manual_seed(0)
ioi_dataset, hl_model = make_ioi_dataset_and_hl(
num_samples, ll_model, NAMES, verbose=True
num_samples, ll_model, verbose=True
)

model_pair = IOI_ModelPair(ll_model=ll_model, hl_model=hl_model, corr=corr)
Expand Down
1 change: 0 additions & 1 deletion iit/utils/train_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from iit.model_pairs.base_model_pair import *
from iit.utils.metric import *
from iit.tasks.ioi import (
NAMES,
make_ioi_dataset_and_hl,
make_corr_dict,
ioi_cfg,
Expand Down

0 comments on commit cd832b8

Please sign in to comment.