add option to return one-hot for ioi
cybershiptrooper committed Sep 7, 2024
1 parent cd832b8 commit e4c4448
Showing 4 changed files with 66 additions and 33 deletions.
2 changes: 1 addition & 1 deletion iit/tasks/ioi/__init__.py
@@ -23,7 +23,7 @@ def make_corr_dict(
 ) -> dict:
     all_attns = [f"blocks.{i}.attn.hook_z" for i in range(ioi_cfg["n_layers"])]
     all_mlps = [f"blocks.{i}.mlp.hook_post" for i in range(ioi_cfg["n_layers"])]
-    attn_idx = Ix[:, :, 1]
+    attn_idx = Ix[:, :, 1, :]
     if eval:
         all_nodes_hook = (
             "blocks.0.hook_resid_pre" if not use_pos_embed else "blocks.0.hook_pos_embed"
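Note on the widened index: in TransformerLens, blocks.{i}.attn.hook_z activations have shape (batch, pos, head_index, d_head), so selecting head 1 needs a slice over all four dimensions. A minimal shape sketch (mine, with a made-up toy config, not from this commit):

# Minimal sketch (hypothetical shapes, not from this commit): why the head
# index for attn.hook_z needs four dimensions.
import torch as t

batch, pos, n_heads, d_head = 2, 9, 4, 16
hook_z = t.randn(batch, pos, n_heads, d_head)  # shape of blocks.{i}.attn.hook_z

# Ix[:, :, 1, :] corresponds to slicing out head 1 across batch, pos, d_head:
head_1 = hook_z[:, :, 1, :]
assert head_1.shape == (batch, pos, d_head)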
91 changes: 62 additions & 29 deletions iit/tasks/ioi/ioi_hl.py
@@ -9,30 +9,34 @@
 
 
 class DuplicateHead(t.nn.Module):
-    def forward(self, tokens : Tensor) -> Tensor:
+    def forward(self, tokens: Tensor) -> Tensor:
         # Write the last previous position of any duplicated token (used at S2)
-        positions = (tokens[..., None, :] == tokens[..., :, None]) # batch seq1 seq2
-        positions = t.triu(positions, diagonal=1) # only consider positions before this one
+        positions = tokens[..., None, :] == tokens[..., :, None]  # batch seq1 seq2
+        positions = t.triu(
+            positions, diagonal=1
+        )  # only consider positions before this one
         indices = positions.nonzero(as_tuple=True)
         ret = t.full_like(tokens, -1)
         ret[indices[0], indices[2]] = indices[1]
         return ret
 
 
 class PreviousHead(t.nn.Module):
     def forward(self, tokens: Tensor) -> Tensor:
         # copy token S1 to token S1+1 (used at S1+1)
         ret = t.full_like(tokens, -1)
         ret[..., 1:] = tokens[..., :-1]
         return ret
 
 
 class InductionHead(t.nn.Module):
     """Induction heads omitted because they're redundant with duplicate heads in IOI"""
 
 
 class SInhibitionHead(t.nn.Module):
     def forward(self, tokens: Tensor, duplicate: Tensor) -> Tensor:
         """
-        when duplicate is not -1, 
+        when duplicate is not -1,
         output a flag to the name mover head to NOT copy this name
         flag is -1 if no duplicate name here, and name token for the name to inhibit
         """
@@ -43,17 +47,25 @@ def forward(self, tokens: Tensor, duplicate: Tensor) -> Tensor:
 
         # extract token positions we care about from duplicate
         duplicate_pos_at_duplicates = t.where(duplicate != -1)
-        duplicate_pos_at_tokens = duplicate[duplicate_pos_at_duplicates[0], duplicate_pos_at_duplicates[1]]
-        duplicate_pos_at_tokens_tup = (duplicate_pos_at_duplicates[0], duplicate_pos_at_tokens)
+        duplicate_pos_at_tokens = duplicate[
+            duplicate_pos_at_duplicates[0], duplicate_pos_at_duplicates[1]
+        ]
+        duplicate_pos_at_tokens_tup = (
+            duplicate_pos_at_duplicates[0],
+            duplicate_pos_at_tokens,
+        )
         duplicate_tokens = tokens[duplicate_pos_at_tokens_tup]
-        assert ret[duplicate_pos_at_duplicates].abs().sum() == 0 # sanity check, to make sure we're not overwriting anything
+        assert (
+            ret[duplicate_pos_at_duplicates].abs().sum() == 0
+        )  # sanity check, to make sure we're not overwriting anything
         # replace ret with the duplicated tokens
         ret[duplicate_pos_at_duplicates] = duplicate_tokens
 
         return ret
 
 
 class NameMoverHead(t.nn.Module):
-    def __init__(self, names: Tensor, d_vocab : int=40):
+    def __init__(self, names: Tensor, d_vocab: int = 40):
         super().__init__()
         self.d_vocab_out = d_vocab
         self.names = names
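The nested indexing in SInhibitionHead.forward is easy to misread, so here is a standalone trace (my illustration, with assumed toy values): t.where picks the (batch, position) pairs that have a duplicate, duplicate[...] recovers the earlier position, and indexing tokens with (batch, earlier position) yields the name token to inhibit.

# Standalone trace (not from this commit) of the indexing above.
import torch as t

tokens = t.tensor([[3, 10, 4, 10, 5]])
duplicate = t.tensor([[-1, -1, -1, 1, -1]])  # DuplicateHead output: pos 3 repeats pos 1

batch_idx, seq_idx = t.where(duplicate != -1)      # (tensor([0]), tensor([3]))
earlier_pos = duplicate[batch_idx, seq_idx]        # tensor([1])
duplicate_tokens = tokens[batch_idx, earlier_pos]  # tensor([10]): the name to inhibit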
@@ -63,22 +75,30 @@ def forward(self, tokens: Tensor, s_inhibition: Tensor) -> Tensor:
         increase logit of all names in the sentence, except those flagged by s_inhibition
         """
         batch, seq = tokens.shape
-        logits = t.zeros((batch, seq, self.d_vocab_out), device=tokens.device) # batch seq d_vocab
+        logits = t.zeros(
+            (batch, seq, self.d_vocab_out), device=tokens.device
+        )  # batch seq d_vocab
         # we want every name to increase its corresponding logit after it appears
         name_mask = t.isin(tokens, self.names)
 
-        batch_indices, seq_indices = t.meshgrid(t.arange(batch), t.arange(seq), indexing='ij')
-
+        batch_indices, seq_indices = t.meshgrid(
+            t.arange(batch), t.arange(seq), indexing="ij"
+        )
         logits[batch_indices, seq_indices, tokens] = 10 * name_mask.float()
         # now decrease the logit of the names that are inhibited
-        logits[batch_indices, seq_indices, s_inhibition] += -15 * s_inhibition.ne(-1).float()
+        logits[batch_indices, seq_indices, s_inhibition] += (
+            -15 * s_inhibition.ne(-1).float()
+        )
         logits = t.cumsum(logits, dim=1)
         return logits
 
 
+# since 0, 3 contains 20, we write
+# a 1 to position 0, 3, 20 of logits
+
 # %%
 
 
 class IOI_HL(HookedRootModule, HLModel):
     """
     Components:
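The +10/-15 constants in NameMoverHead.forward combine through the cumsum so an inhibited name ends up below the uninhibited one. A worked check (mine; it assumes s_inhibition flags only the duplicate position, matching the SInhibitionHead indexing above):

# Worked check (not from this commit): cumulative logits at the last position.
import torch as t
from iit.tasks.ioi.ioi_hl import NameMoverHead

head = NameMoverHead(names=t.tensor([5, 10]), d_vocab=21)
tokens = t.tensor([[3, 10, 4, 10, 5]])
s_inhibition = t.tensor([[-1, -1, -1, 10, -1]])  # assumed SInhibitionHead output
logits = head(tokens, s_inhibition)
# Repeated name 10 accumulates 10 + 10 - 15 = 5; name 5 accumulates 10,
# so argmax at the last position picks 5, the indirect object.
print(logits[0, -1, 10], logits[0, -1, 5])  # tensor(5.) tensor(10.)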
@@ -88,14 +108,18 @@ class IOI_HL(HookedRootModule, HLModel):
     - S-inhibition heads: Inhibit attention of Name Mover Heads to S1 and S2 tokens
     - Name mover heads: Copy all previous names in the sentence
     """
+
     def __init__(
-        self,
-        d_vocab: int,
-        names: Tensor,
-        device: t.device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
-    ):
+        self,
+        d_vocab: int,
+        names: Tensor,
+        device: t.device = t.device("cuda") if t.cuda.is_available() else t.device("cpu"),
+        return_one_hot: bool = True,
+    ):
         super().__init__()
-        assert isinstance(names, Tensor), ValueError(f"Expected a tensor, got {type(names)}")
+        assert isinstance(names, Tensor), ValueError(
+            f"Expected a tensor, got {type(names)}"
+        )
         self.all_nodes_hook = HookPoint()
         self.duplicate_head = DuplicateHead()
         self.hook_duplicate = HookPoint()
@@ -110,19 +134,22 @@ def __init__(
         self.cfg = Namespace(
             d_vocab=d_vocab,
             d_vocab_out=d_vocab,
-            device=device
+            device=device,
+            return_one_hot=return_one_hot,
         )
         self.setup()
 
     @property
     def device(self) -> t.device:
         return self.cfg.device
 
     def is_categorical(self) -> bool:
         return True
 
     def forward(self, args: Tensor | tuple, verbose: bool = False) -> Tensor:
-        show: Callable[[t.Any], None] = lambda *args, **kwargs: print(*args, **kwargs) if verbose else None
+        show: Callable[[t.Any], None] = lambda *args, **kwargs: (
+            print(*args, **kwargs) if verbose else None
+        )
         if isinstance(args, Tensor):
             input = args
         elif isinstance(args, tuple):
@@ -150,10 +177,16 @@ def forward(self, args: Tensor | tuple, verbose: bool = False) -> Tensor:
         s_inhibition = self.hook_s_inhibition(s_inhibition)
         show(f"s_inhibition: {s_inhibition}")
         out = self.name_mover_head(input, s_inhibition)
+        if self.cfg.return_one_hot:
+            out = t.nn.functional.one_hot(
+                t.argmax(out, dim=-1), num_classes=self.d_vocab
+            ).float()
         assert out.shape == input.shape + (self.d_vocab,)
         out = self.hook_name_mover(out)
         show(f"out: {t.argmax(out, dim=-1)}")
         if not batched:
             out = out[0]
         return out
 
 
 # %%
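The net effect of the new flag, as a usage sketch (my example; the name ids are made up, d_vocab matches the test file): with return_one_hot=True (the default) the model returns the argmax of the name-mover logits as a one-hot float tensor, while return_one_hot=False preserves the raw cumulative logits, and both encode the same prediction.

# Usage sketch (not from this commit); toy name ids, run on CPU.
import torch as t
from iit.tasks.ioi.ioi_hl import IOI_HL

names = t.tensor([5, 9, 10])
tokens = t.tensor([[3, 10, 4, 10, 5, 9, 2, 6, 5]])
cpu = t.device("cpu")

logits = IOI_HL(d_vocab=21, names=names, device=cpu, return_one_hot=False)((tokens, None, None))
one_hot = IOI_HL(d_vocab=21, names=names, device=cpu)((tokens, None, None))  # default: one-hot

assert one_hot.shape == logits.shape == (1, 9, 21)
assert t.equal(one_hot.argmax(-1), logits.argmax(-1))  # same prediction
assert t.all(one_hot.sum(-1) == 1)  # exactly one 1 per position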
4 changes: 2 additions & 2 deletions iit/tasks/ioi/utils.py
@@ -28,8 +28,8 @@ def make_ioi_dataset_and_hl(
     )
 
     ioi_names = t.tensor(
-        list(set([ioi_dataset_tl[i]["IO"].item() for i in range(len(ioi_dataset_tl))]))
-    ).to(device)
+        [ll_model.tokenizer.encode(" " + name) for name in ioi_dataset_tl.names]
+    ).flatten()
     hl_model = IOI_HL(d_vocab=ll_model.cfg.d_vocab_out, names=ioi_names, device=device)
 
     ioi_dataset = IOIDatasetWrapper(
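The new construction derives the name set from the dataset's name list rather than from the sampled IO tokens, encoding each name with a leading space. A quick illustration (mine; it assumes an ll_model with a GPT-2-style BPE tokenizer, as in make_ioi_dataset_and_hl, where each " Name" is a single token, which is what the .flatten() relies on):

# Illustration (not from this commit): leading-space encoding with GPT-2-style BPE.
# A name mid-sentence is preceded by a space, so " Mary" is one token, while
# "Mary" without the space may tokenize to a different id or split.
for name in ["Mary", "John"]:
    with_space = ll_model.tokenizer.encode(" " + name)
    without = ll_model.tokenizer.encode(name)
    print(name, with_space, without)
    assert len(with_space) == 1  # the .flatten() above assumes single-token names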
2 changes: 1 addition & 1 deletion tests/test_ioi_hl.py
@@ -70,7 +70,7 @@ def test_name_mover_head() -> None:
 
 
 def test_ioi_hl() -> None:
-    a = IOI_HL(d_vocab=21, names=IOI_TEST_NAMES)(
+    a = IOI_HL(d_vocab=21, names=IOI_TEST_NAMES, return_one_hot=False)(
         (t.tensor([[3, 10, 4, 10, 5, 9, 2, 6, 5]]), None, None)
     )
     assert nonzero_values(a[0]).equal(
