From 3ab717dcfebaf8efc9fc2256d7484bd05705f1ef Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Tue, 30 Jul 2024 10:31:33 +0800 Subject: [PATCH] fix(training): stable GPU usage --- src/lm_saes/activation/token_source.py | 4 ++-- src/lm_saes/runner.py | 2 ++ src/lm_saes/sae.py | 2 +- src/lm_saes/sae_training.py | 8 ++++---- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py index 3b15c559..65e9681b 100644 --- a/src/lm_saes/activation/token_source.py +++ b/src/lm_saes/activation/token_source.py @@ -116,13 +116,13 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig): if dist.is_initialized(): shard_id = dist.get_rank() shard = dataset.shard( - num_shards=dist.get_world_size(), index=shard_id + num_shards=dist.get_world_size(), index=shard_id, contiguous=True ) else: shard = dataset - dataloader = DataLoader(shard, batch_size=cfg.store_batch_size) + dataloader = DataLoader(shard, batch_size=cfg.store_batch_size, num_workers=4, prefetch_factor=4, pin_memory=True) return dataloader @staticmethod diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index aaab27cc..ad45ddcb 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -90,6 +90,8 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): model.eval() activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) + + if ( cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None or cfg.sae.init_decoder_norm is None diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index d2384227..ddec60e3 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -262,7 +262,7 @@ def encode( if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: x = ( - x - self.decoder.bias.to_local() # type: ignore + x - self.decoder.bias.data.to_local() # type: ignore if self.cfg.tp_size > 1 else x - self.decoder.bias ) diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index a7d85546..b5c9d2ec 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -172,13 +172,13 @@ def train_sae( if cfg.wandb.log_to_wandb and (is_master()): feature_sparsity = act_freq_scores / n_frac_active_tokens log_feature_sparsity = torch.log10(feature_sparsity + 1e-10) - wandb_histogram = wandb.Histogram( - log_feature_sparsity.detach().cpu().float().numpy() - ) + # wandb_histogram = wandb.Histogram( + # log_feature_sparsity.detach().cpu().float().numpy() + # ) wandb.log( { "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(), - "plots/feature_density_line_chart": wandb_histogram, + # "plots/feature_density_line_chart": wandb_histogram, "sparsity/below_1e-5": (feature_sparsity < 1e-5) .sum() .item(),