From caf1ff6bb587b3628d13df397993650c24deea52 Mon Sep 17 00:00:00 2001
From: Ivan Arcuschin
Date: Sun, 15 Sep 2024 23:26:08 +0100
Subject: [PATCH 1/2] Remove hardcoded entity for iit training

---
 iit/model_pairs/base_model_pair.py        | 7 +++----
 iit/model_pairs/probed_sequential_pair.py | 3 +--
 iit/utils/config.py                       | 3 +--
 3 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py
index 1d9428a..78434b5 100644
--- a/iit/model_pairs/base_model_pair.py
+++ b/iit/model_pairs/base_model_pair.py
@@ -12,7 +12,6 @@ import wandb # type: ignore
 
 from iit.model_pairs.ll_model import LLModel
 from iit.utils.nodes import HLNode, LLNode
-from iit.utils.config import WANDB_ENTITY
 from iit.utils.correspondence import Correspondence
 from iit.utils.iit_dataset import IITDataset
 from iit.utils.index import Ix, TorchIndex
@@ -223,7 +222,8 @@ def train(
         test_set: IITDataset,
         epochs: int = 1000,
         use_wandb: bool = False,
-        wandb_name_suffix: str = "",
+        wandb_project: str = "iit",
+        wandb_name: str = "",
     ) -> None:
         training_args = self.training_args
         print(f"{training_args=}")
@@ -287,8 +287,7 @@ def linear_lr(step: int) -> float:
             print("No LR scheduler set up")
 
         if use_wandb and not wandb.run:
-            wandb.init(project="iit", name=wandb_name_suffix,
-                       entity=WANDB_ENTITY)
+            wandb.init(project=wandb_project, name=wandb_name)
 
         if use_wandb:
             wandb.config.update(training_args)
diff --git a/iit/model_pairs/probed_sequential_pair.py b/iit/model_pairs/probed_sequential_pair.py
index 255ad63..a5c4f6a 100644
--- a/iit/model_pairs/probed_sequential_pair.py
+++ b/iit/model_pairs/probed_sequential_pair.py
@@ -8,7 +8,6 @@
 from transformer_lens.hook_points import HookedRootModule #type: ignore
 
 from iit.model_pairs.iit_model_pair import IITModelPair
-from iit.utils.config import WANDB_ENTITY
 from iit.utils.probes import construct_probes #type: ignore
 from iit.utils.correspondence import Correspondence
 from iit.utils.iit_dataset import IITDataset
@@ -112,7 +111,7 @@
         loss_fn = t.nn.CrossEntropyLoss()
 
         if use_wandb and not wandb.run:
-            wandb.init(project="iit", entity=WANDB_ENTITY)
+            wandb.init(project="iit")
 
         if use_wandb:
             wandb.config.update(training_args)
diff --git a/iit/utils/config.py b/iit/utils/config.py
index e643053..a0ee7b7 100644
--- a/iit/utils/config.py
+++ b/iit/utils/config.py
@@ -1,4 +1,3 @@
 import torch
 
-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-WANDB_ENTITY = "cybershiptrooper" #TODO: This should be editable by the user at runtime
\ No newline at end of file
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
\ No newline at end of file

From 0f3369f1cc020b7ff2b6ddeda38e56b0305d5440 Mon Sep 17 00:00:00 2001
From: Ivan Arcuschin
Date: Sun, 15 Sep 2024 23:44:27 +0100
Subject: [PATCH 2/2] Minor fix to previous commit

---
 iit/model_pairs/probed_sequential_pair.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/iit/model_pairs/probed_sequential_pair.py b/iit/model_pairs/probed_sequential_pair.py
index a5c4f6a..74ec93d 100644
--- a/iit/model_pairs/probed_sequential_pair.py
+++ b/iit/model_pairs/probed_sequential_pair.py
@@ -83,7 +83,8 @@ def train(
         test_set: IITDataset,
         epochs: int = 1000,
         use_wandb: bool = False,
-        wandb_name_suffix: str = "",
+        wandb_project: str = "iit",
+        wandb_name: str = "",
         optimizer_kwargs: dict = {},
     ) -> None:
         training_args = self.training_args
@@ -111,7 +112,7 @@
         loss_fn = t.nn.CrossEntropyLoss()
 
         if use_wandb and not wandb.run:
-            wandb.init(project="iit")
+            wandb.init(project=wandb_project, name=wandb_name)
 
         if use_wandb:
             wandb.config.update(training_args)
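
For context, a minimal caller-side sketch of the train() API after both patches. The make_model_pair() factory and the two dataset variables are hypothetical stand-ins for however a model pair and its IITDataset instances are built in a given experiment; only the wandb_project and wandb_name keyword arguments (and the removal of the pinned entity) come from this change:

    # Hypothetical usage sketch: make_model_pair, train_set, and test_set
    # are placeholders, not names defined in this repository.
    model_pair = make_model_pair()
    model_pair.train(
        train_set,                    # IITDataset
        test_set,                     # IITDataset
        epochs=1000,
        use_wandb=True,
        wandb_project="my-project",   # previously hardcoded to "iit"
        wandb_name="baseline-run",    # replaces the old wandb_name_suffix
    )

With the hardcoded entity gone, wandb.init() resolves the entity from the user's login or the WANDB_ENTITY environment variable, so runs are no longer pinned to a single account.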