Merge pull request #19 from FlyingPumba/fix/hardcoded-wandb-entity
Remove hardcoded entity for iit training
cybershiptrooper authored Sep 16, 2024
2 parents 04c5c4c + 0f3369f commit f3c96e8
Showing 3 changed files with 7 additions and 9 deletions.
7 changes: 3 additions & 4 deletions  iit/model_pairs/base_model_pair.py

@@ -12,7 +12,6 @@
 import wandb # type: ignore
 from iit.model_pairs.ll_model import LLModel
 from iit.utils.nodes import HLNode, LLNode
-from iit.utils.config import WANDB_ENTITY
 from iit.utils.correspondence import Correspondence
 from iit.utils.iit_dataset import IITDataset
 from iit.utils.index import Ix, TorchIndex

@@ -223,7 +222,8 @@ def train(
         test_set: IITDataset,
         epochs: int = 1000,
         use_wandb: bool = False,
-        wandb_name_suffix: str = "",
+        wandb_project: str = "iit",
+        wandb_name: str = "",
     ) -> None:
         training_args = self.training_args
         print(f"{training_args=}")

@@ -287,8 +287,7 @@ def linear_lr(step: int) -> float:
             print("No LR scheduler set up")

         if use_wandb and not wandb.run:
-            wandb.init(project="iit", name=wandb_name_suffix,
-                       entity=WANDB_ENTITY)
+            wandb.init(project=wandb_project, name=wandb_name)

         if use_wandb:
             wandb.config.update(training_args)
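With this change the W&B project and run name are chosen by the caller instead of being fixed inside train(), and no entity is hardcoded. A minimal sketch of a call site after this commit, assuming a constructed model pair and datasets; everything except the keyword arguments visible in the diff (epochs, use_wandb, wandb_project, wandb_name) is illustrative:

    # Hypothetical usage after this commit; model_pair, train_set and test_set are
    # assumed to exist. With no explicit entity, wandb falls back to the logged-in
    # account's default entity (or the WANDB_ENTITY environment variable, if set).
    model_pair.train(
        train_set,
        test_set,
        epochs=1000,
        use_wandb=True,
        wandb_project="iit",         # new keyword: replaces the hardcoded project string
        wandb_name="my-experiment",  # new keyword: replaces wandb_name_suffix
    )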
6 changes: 3 additions & 3 deletions  iit/model_pairs/probed_sequential_pair.py

@@ -8,7 +8,6 @@
 from transformer_lens.hook_points import HookedRootModule #type: ignore

 from iit.model_pairs.iit_model_pair import IITModelPair
-from iit.utils.config import WANDB_ENTITY
 from iit.utils.probes import construct_probes #type: ignore
 from iit.utils.correspondence import Correspondence
 from iit.utils.iit_dataset import IITDataset

@@ -84,7 +83,8 @@ def train(
         test_set: IITDataset,
         epochs: int = 1000,
         use_wandb: bool = False,
-        wandb_name_suffix: str = "",
+        wandb_project: str = "iit",
+        wandb_name: str = "",
         optimizer_kwargs: dict = {},
     ) -> None:
         training_args = self.training_args

@@ -112,7 +112,7 @@
         loss_fn = t.nn.CrossEntropyLoss()

         if use_wandb and not wandb.run:
-            wandb.init(project="iit", entity=WANDB_ENTITY)
+            wandb.init(project=wandb_project, name=wandb_name)

         if use_wandb:
             wandb.config.update(training_args)
3 changes: 1 addition & 2 deletions  iit/utils/config.py

@@ -1,4 +1,3 @@
 import torch

-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-WANDB_ENTITY = "cybershiptrooper" #TODO: This should be editable by the user at runtime
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
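With WANDB_ENTITY gone from iit/utils/config.py, the entity is resolved by Weights & Biases at run time. A sketch of two standard ways to point runs at a specific entity (not part of this diff; "my-team" and "my-run" are placeholders):

    import os
    os.environ["WANDB_ENTITY"] = "my-team"  # read by wandb.init() when no entity is passed

    # or pass the entity explicitly to the init call:
    import wandb
    wandb.init(project="iit", name="my-run", entity="my-team")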
