From aea84c776dc6cc16a7616e15a6b4d9e92c56ef23 Mon Sep 17 00:00:00 2001 From: Rohan Gupta Date: Sun, 8 Sep 2024 00:53:59 +0530 Subject: [PATCH] device bugfix --- iit/tasks/ioi/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iit/tasks/ioi/utils.py b/iit/tasks/ioi/utils.py index 461b26b..5266b43 100644 --- a/iit/tasks/ioi/utils.py +++ b/iit/tasks/ioi/utils.py @@ -29,7 +29,7 @@ def make_ioi_dataset_and_hl( ioi_names = t.tensor( [ll_model.tokenizer.encode(" " + name) for name in ioi_dataset_tl.names] - ).flatten() + ).flatten().to(device) hl_model = IOI_HL(d_vocab=ll_model.cfg.d_vocab_out, names=ioi_names, device=device) ioi_dataset = IOIDatasetWrapper(