diff --git a/iit/model_pairs/base_model_pair.py b/iit/model_pairs/base_model_pair.py
index 1d9428a..78434b5 100644
--- a/iit/model_pairs/base_model_pair.py
+++ b/iit/model_pairs/base_model_pair.py
@@ -12,7 +12,6 @@
 import wandb # type: ignore
 from iit.model_pairs.ll_model import LLModel
 from iit.utils.nodes import HLNode, LLNode
-from iit.utils.config import WANDB_ENTITY
 from iit.utils.correspondence import Correspondence
 from iit.utils.iit_dataset import IITDataset
 from iit.utils.index import Ix, TorchIndex
@@ -223,7 +222,8 @@ def train(
         test_set: IITDataset,
         epochs: int = 1000,
         use_wandb: bool = False,
-        wandb_name_suffix: str = "",
+        wandb_project: str = "iit",
+        wandb_name: str = "",
     ) -> None:
         training_args = self.training_args
         print(f"{training_args=}")
@@ -287,8 +287,7 @@ def linear_lr(step: int) -> float:
             print("No LR scheduler set up")
 
         if use_wandb and not wandb.run:
-            wandb.init(project="iit", name=wandb_name_suffix,
-                       entity=WANDB_ENTITY)
+            wandb.init(project=wandb_project, name=wandb_name)
 
         if use_wandb:
             wandb.config.update(training_args)
diff --git a/iit/model_pairs/probed_sequential_pair.py b/iit/model_pairs/probed_sequential_pair.py
index 255ad63..74ec93d 100644
--- a/iit/model_pairs/probed_sequential_pair.py
+++ b/iit/model_pairs/probed_sequential_pair.py
@@ -8,7 +8,6 @@
 from transformer_lens.hook_points import HookedRootModule #type: ignore
 
 from iit.model_pairs.iit_model_pair import IITModelPair
-from iit.utils.config import WANDB_ENTITY
 from iit.utils.probes import construct_probes #type: ignore
 from iit.utils.correspondence import Correspondence
 from iit.utils.iit_dataset import IITDataset
@@ -84,7 +83,8 @@ def train(
         test_set: IITDataset,
         epochs: int = 1000,
         use_wandb: bool = False,
-        wandb_name_suffix: str = "",
+        wandb_project: str = "iit",
+        wandb_name: str = "",
         optimizer_kwargs: dict = {},
     ) -> None:
         training_args = self.training_args
@@ -112,7 +112,7 @@ def train(
         loss_fn = t.nn.CrossEntropyLoss()
 
         if use_wandb and not wandb.run:
-            wandb.init(project="iit", entity=WANDB_ENTITY)
+            wandb.init(project=wandb_project, name=wandb_name)
 
         if use_wandb:
             wandb.config.update(training_args)
diff --git a/iit/utils/config.py b/iit/utils/config.py
index e643053..a0ee7b7 100644
--- a/iit/utils/config.py
+++ b/iit/utils/config.py
@@ -1,4 +1,3 @@
 import torch
 
-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-WANDB_ENTITY = "cybershiptrooper" #TODO: This should be editable by the user at runtime
\ No newline at end of file
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
\ No newline at end of file
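
Usage note: a minimal caller sketch for the new keyword arguments. This assumes
an already-constructed model pair and IIT datasets; model_pair, train_set, and
test_set are hypothetical names, and the leading positional arguments follow
the signature visible in the hunks above.

    # Hypothetical caller; only the wandb-related kwargs are new in this diff.
    model_pair.train(
        train_set,
        test_set,
        epochs=1000,
        use_wandb=True,
        wandb_project="my-iit-runs",  # replaces the hard-coded "iit" project
        wandb_name="experiment-1",    # replaces wandb_name_suffix
    )

With the WANDB_ENTITY constant removed, wandb.init() falls back to the default
entity configured in the caller's local wandb settings, so runs no longer land
in a maintainer-specific account.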