diff --git a/iit/tasks/ioi/__init__.py b/iit/tasks/ioi/__init__.py index c280470..18cb9b7 100644 --- a/iit/tasks/ioi/__init__.py +++ b/iit/tasks/ioi/__init__.py @@ -23,7 +23,7 @@ def make_corr_dict( ) -> dict: all_attns = [f"blocks.{i}.attn.hook_z" for i in range(ioi_cfg["n_layers"])] all_mlps = [f"blocks.{i}.mlp.hook_post" for i in range(ioi_cfg["n_layers"])] - attn_idx = Ix[:, :, 1] + attn_idx = Ix[:, :, 1, :] if eval: all_nodes_hook = ( "blocks.0.hook_resid_pre" if not use_pos_embed else "blocks.0.hook_pos_embed" diff --git a/iit/tasks/ioi/ioi_hl.py b/iit/tasks/ioi/ioi_hl.py index 4068f00..e6c4b56 100644 --- a/iit/tasks/ioi/ioi_hl.py +++ b/iit/tasks/ioi/ioi_hl.py @@ -9,15 +9,18 @@ class DuplicateHead(t.nn.Module): - def forward(self, tokens : Tensor) -> Tensor: + def forward(self, tokens: Tensor) -> Tensor: # Write the last previous position of any duplicated token (used at S2) - positions = (tokens[..., None, :] == tokens[..., :, None]) # batch seq1 seq2 - positions = t.triu(positions, diagonal=1) # only consider positions before this one + positions = tokens[..., None, :] == tokens[..., :, None] # batch seq1 seq2 + positions = t.triu( + positions, diagonal=1 + ) # only consider positions before this one indices = positions.nonzero(as_tuple=True) ret = t.full_like(tokens, -1) ret[indices[0], indices[2]] = indices[1] return ret - + + class PreviousHead(t.nn.Module): def forward(self, tokens: Tensor) -> Tensor: # copy token S1 to token S1+1 (used at S1+1) @@ -25,14 +28,15 @@ def forward(self, tokens: Tensor) -> Tensor: ret[..., 1:] = tokens[..., :-1] return ret + class InductionHead(t.nn.Module): """Induction heads omitted because they're redundant with duplicate heads in IOI""" - + class SInhibitionHead(t.nn.Module): def forward(self, tokens: Tensor, duplicate: Tensor) -> Tensor: """ - when duplicate is not -1, + when duplicate is not -1, output a flag to the name mover head to NOT copy this name flag is -1 if no duplicate name here, and name token for the name to inhibit """ @@ -43,17 +47,25 @@ def forward(self, tokens: Tensor, duplicate: Tensor) -> Tensor: # extract token positions we care about from duplicate duplicate_pos_at_duplicates = t.where(duplicate != -1) - duplicate_pos_at_tokens = duplicate[duplicate_pos_at_duplicates[0], duplicate_pos_at_duplicates[1]] - duplicate_pos_at_tokens_tup = (duplicate_pos_at_duplicates[0], duplicate_pos_at_tokens) + duplicate_pos_at_tokens = duplicate[ + duplicate_pos_at_duplicates[0], duplicate_pos_at_duplicates[1] + ] + duplicate_pos_at_tokens_tup = ( + duplicate_pos_at_duplicates[0], + duplicate_pos_at_tokens, + ) duplicate_tokens = tokens[duplicate_pos_at_tokens_tup] - assert ret[duplicate_pos_at_duplicates].abs().sum() == 0 # sanity check, to make sure we're not overwriting anything + assert ( + ret[duplicate_pos_at_duplicates].abs().sum() == 0 + ) # sanity check, to make sure we're not overwriting anything # replace ret with the duplicated tokens ret[duplicate_pos_at_duplicates] = duplicate_tokens - + return ret - + + class NameMoverHead(t.nn.Module): - def __init__(self, names: Tensor, d_vocab : int=40): + def __init__(self, names: Tensor, d_vocab: int = 40): super().__init__() self.d_vocab_out = d_vocab self.names = names @@ -63,22 +75,30 @@ def forward(self, tokens: Tensor, s_inhibition: Tensor) -> Tensor: increase logit of all names in the sentence, except those flagged by s_inhibition """ batch, seq = tokens.shape - logits = t.zeros((batch, seq, self.d_vocab_out), device=tokens.device) # batch seq d_vocab + logits = t.zeros( + (batch, seq, self.d_vocab_out), device=tokens.device + ) # batch seq d_vocab # we want every name to increase its corresponding logit after it appears name_mask = t.isin(tokens, self.names) - - batch_indices, seq_indices = t.meshgrid(t.arange(batch), t.arange(seq), indexing='ij') + + batch_indices, seq_indices = t.meshgrid( + t.arange(batch), t.arange(seq), indexing="ij" + ) logits[batch_indices, seq_indices, tokens] = 10 * name_mask.float() # now decrease the logit of the names that are inhibited - logits[batch_indices, seq_indices, s_inhibition] += -15 * s_inhibition.ne(-1).float() + logits[batch_indices, seq_indices, s_inhibition] += ( + -15 * s_inhibition.ne(-1).float() + ) logits = t.cumsum(logits, dim=1) return logits - + + # since 0, 3 contains 20, we write # a 1 to position 0, 3, 20 of logits - + # %% - + + class IOI_HL(HookedRootModule, HLModel): """ Components: @@ -88,14 +108,18 @@ class IOI_HL(HookedRootModule, HLModel): - S-inhibition heads: Inhibit attention of Name Mover Heads to S1 and S2 tokens - Name mover heads: Copy all previous names in the sentence """ + def __init__( - self, - d_vocab: int, - names: Tensor, - device: t.device = t.device("cuda") if t.cuda.is_available() else t.device("cpu") - ): + self, + d_vocab: int, + names: Tensor, + device: t.device = t.device("cuda") if t.cuda.is_available() else t.device("cpu"), + return_one_hot: bool = True, + ): super().__init__() - assert isinstance(names, Tensor), ValueError(f"Expected a tensor, got {type(names)}") + assert isinstance(names, Tensor), ValueError( + f"Expected a tensor, got {type(names)}" + ) self.all_nodes_hook = HookPoint() self.duplicate_head = DuplicateHead() self.hook_duplicate = HookPoint() @@ -110,19 +134,22 @@ def __init__( self.cfg = Namespace( d_vocab=d_vocab, d_vocab_out=d_vocab, - device=device + device=device, + return_one_hot=return_one_hot, ) self.setup() - + @property def device(self) -> t.device: return self.cfg.device def is_categorical(self) -> bool: return True - + def forward(self, args: Tensor | tuple, verbose: bool = False) -> Tensor: - show: Callable[[t.Any], None] = lambda *args, **kwargs: print(*args, **kwargs) if verbose else None + show: Callable[[t.Any], None] = lambda *args, **kwargs: ( + print(*args, **kwargs) if verbose else None + ) if isinstance(args, Tensor): input = args elif isinstance(args, tuple): @@ -150,10 +177,16 @@ def forward(self, args: Tensor | tuple, verbose: bool = False) -> Tensor: s_inhibition = self.hook_s_inhibition(s_inhibition) show(f"s_inhibition: {s_inhibition}") out = self.name_mover_head(input, s_inhibition) + if self.cfg.return_one_hot: + out = t.nn.functional.one_hot( + t.argmax(out, dim=-1), num_classes=self.d_vocab + ).float() assert out.shape == input.shape + (self.d_vocab,) out = self.hook_name_mover(out) show(f"out: {t.argmax(out, dim=-1)}") if not batched: out = out[0] return out + + # %% diff --git a/iit/tasks/ioi/utils.py b/iit/tasks/ioi/utils.py index 2fc4efc..461b26b 100644 --- a/iit/tasks/ioi/utils.py +++ b/iit/tasks/ioi/utils.py @@ -28,8 +28,8 @@ def make_ioi_dataset_and_hl( ) ioi_names = t.tensor( - list(set([ioi_dataset_tl[i]["IO"].item() for i in range(len(ioi_dataset_tl))])) - ).to(device) + [ll_model.tokenizer.encode(" " + name) for name in ioi_dataset_tl.names] + ).flatten() hl_model = IOI_HL(d_vocab=ll_model.cfg.d_vocab_out, names=ioi_names, device=device) ioi_dataset = IOIDatasetWrapper( diff --git a/tests/test_ioi_hl.py b/tests/test_ioi_hl.py index 8a7d01a..5e6173a 100644 --- a/tests/test_ioi_hl.py +++ b/tests/test_ioi_hl.py @@ -70,7 +70,7 @@ def test_name_mover_head() -> None: def test_ioi_hl() -> None: - a = IOI_HL(d_vocab=21, names=IOI_TEST_NAMES)( + a = IOI_HL(d_vocab=21, names=IOI_TEST_NAMES, return_one_hot=False)( (t.tensor([[3, 10, 4, 10, 5, 9, 2, 6, 5]]), None, None) ) assert nonzero_values(a[0]).equal(