Skip to content

Commit

Permalink
feat: use wandb auto resuming feature
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Aug 23, 2024
1 parent bb49d16 commit ad3a344
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
7 changes: 4 additions & 3 deletions open_diloco/train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,11 @@ def train(config: Config):
sharding_strategy = ShardingStrategy.NO_SHARD
log("Hivemind is used, ShardingStrategy.NO_SHARD is used")

resume_from_ckpt, resume_path = get_resume_info(config.ckpt)

if rank == 0:
logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
metric_logger = logger_cls(project=config.project, config=config.model_dump())
metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=resume_from_ckpt)

if config.hv is not None:
log("hivemind diloco enabled")
Expand Down Expand Up @@ -257,8 +259,6 @@ def scheduler_fn(opt):
num_training_steps=config.total_steps,
)

resume_from_ckpt, resume_path = get_resume_info(config.ckpt)

if config.hv is not None:
if resume_from_ckpt:
# We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
Expand Down Expand Up @@ -510,6 +510,7 @@ def scheduler_fn(opt):

if config.max_steps is not None and real_step >= config.max_steps:
break

log("Training completed.")
if rank == 0:
metric_logger.finish()
Expand Down
8 changes: 5 additions & 3 deletions open_diloco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ def finish(self): ...


class WandbLogger:
def __init__(self, project, config):
wandb.init(project=project, config=config)
def __init__(self, project, config, resume: bool):
wandb.init(
project=project, config=config, resume="auto" if resume else None
) # make wandb reuse the same run id if possible

def log(self, metrics: dict[str, Any]):
wandb.log(metrics)
Expand All @@ -187,7 +189,7 @@ def finish(self):


class DummyLogger:
def __init__(self, project, config):
def __init__(self, project, config, *args, **kwargs):
self.project = project
self.config = config
open(project, "a").close() # Create an empty file at the project path
Expand Down

0 comments on commit ad3a344

Please sign in to comment.