From c73c64a0fe6825fab06b890a5005e67820feb54f Mon Sep 17 00:00:00 2001
From: Frankstein <20307140057@fudan.edu.cn>
Date: Wed, 26 Jun 2024 17:21:06 +0800
Subject: [PATCH] fix: resolve DDP-related synchronization bug

---
 src/lm_saes/runner.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index e5f02247..51550377 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -28,11 +28,12 @@ from lm_saes.sae_training import prune_sae, train_sae
 from lm_saes.analysis.sample_feature_activations import sample_feature_activations
 from lm_saes.analysis.features_to_logits import features_to_logits
-
+from torch.nn.parallel import DistributedDataParallel as DDP

 def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    if (not cfg.use_ddp) or cfg.rank == 0:
+        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

     if cfg.finetuning:
@@ -68,7 +69,9 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
         tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
-
+    if cfg.use_ddp:
+        _ = DDP(model, device_ids=[cfg.rank])
+        _ = DDP(sae, device_ids=[cfg.rank])
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
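
For context, here is a minimal standalone sketch of the pattern this patch introduces (rank-0-only file writes plus DistributedDataParallel wrapping). It is not the project's actual runner: the script name, directory path, and toy model are assumptions; `cfg.use_ddp` and `cfg.rank` in the patch correspond to `use_ddp` and `rank` below, which are derived from torchrun-style environment variables.

# Launch with: torchrun --nproc_per_node=<N> ddp_sketch.py  (hypothetical script name)
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def main() -> None:
    use_ddp = "RANK" in os.environ  # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK
    rank = int(os.environ.get("LOCAL_RANK", 0))

    if use_ddp:
        torch.cuda.set_device(rank)
        dist.init_process_group(backend="nccl")

    # Only rank 0 touches the shared result directory, mirroring the
    # `if (not cfg.use_ddp) or cfg.rank == 0:` guard added by the patch.
    if (not use_ddp) or rank == 0:
        os.makedirs("results/exp", exist_ok=True)  # assumed path, for illustration

    model = nn.Linear(16, 16).to(rank if use_ddp else "cpu")  # toy stand-in module
    if use_ddp:
        # Standard usage keeps and uses the DDP wrapper for forward passes.
        model = DDP(model, device_ids=[rank])

    # ... training / evaluation loop would go here ...

    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()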