Skip to content

Commit

Permalink
make some changes to ioi
Browse files Browse the repository at this point in the history
  • Loading branch information
cybershiptrooper committed Sep 4, 2024
1 parent 32c8e6e commit 73c1663
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions iit/tasks/ioi/ioi_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ class IOI_HL(HookedRootModule, HLModel):
- S-inhibition heads: Inhibit attention of Name Mover Heads to S1 and S2 tokens
- Name mover heads: Copy all previous names in the sentence
"""
def __init__(self, d_vocab: int, names: Tensor):
def __init__(
self,
d_vocab: int,
names: Tensor,
device: t.device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
):
super().__init__()
self.all_nodes_hook = HookPoint()
self.duplicate_head = DuplicateHead()
Expand All @@ -104,7 +109,7 @@ def __init__(self, d_vocab: int, names: Tensor):
self.cfg = Namespace(
d_vocab=d_vocab,
d_vocab_out=d_vocab,
device=t.device("cuda") if t.cuda.is_available() else t.device("cpu")
device=device
)
self.setup()

Expand Down
2 changes: 1 addition & 1 deletion iit/tasks/ioi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def make_ioi_dataset_and_hl(
ioi_names = t.tensor(
list(set([ioi_dataset_tl[i]["IO"].item() for i in range(len(ioi_dataset_tl))]))
).to(device)
hl_model = IOI_HL(d_vocab=ll_model.cfg.d_vocab_out, names=ioi_names).to(device)
hl_model = IOI_HL(d_vocab=ll_model.cfg.d_vocab_out, names=ioi_names, device=device)

ioi_dataset = IOIDatasetWrapper(
num_samples=num_samples,
Expand Down

0 comments on commit 73c1663

Please sign in to comment.