Skip to content

Commit

Permalink
fix(training): stable GPU usage
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfinfdu committed Jul 30, 2024
1 parent e5ad29e commit 3ab717d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/lm_saes/activation/token_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig):
if dist.is_initialized():
shard_id = dist.get_rank()
shard = dataset.shard(
-                num_shards=dist.get_world_size(), index=shard_id
+                num_shards=dist.get_world_size(), index=shard_id, contiguous=True
)
else:
shard = dataset


-        dataloader = DataLoader(shard, batch_size=cfg.store_batch_size)
+        dataloader = DataLoader(shard, batch_size=cfg.store_batch_size, num_workers=4, prefetch_factor=4, pin_memory=True)
return dataloader

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions src/lm_saes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)



if (
cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
or cfg.sae.init_decoder_norm is None
Expand Down
2 changes: 1 addition & 1 deletion src/lm_saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def encode(

if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
x = (
-                x - self.decoder.bias.to_local() # type: ignore
+                x - self.decoder.bias.data.to_local() # type: ignore
if self.cfg.tp_size > 1
else x - self.decoder.bias
)
Expand Down
8 changes: 4 additions & 4 deletions src/lm_saes/sae_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,13 @@ def train_sae(
if cfg.wandb.log_to_wandb and (is_master()):
feature_sparsity = act_freq_scores / n_frac_active_tokens
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
-            wandb_histogram = wandb.Histogram(
-                log_feature_sparsity.detach().cpu().float().numpy()
-            )
+            # wandb_histogram = wandb.Histogram(
+            #     log_feature_sparsity.detach().cpu().float().numpy()
+            # )
wandb.log(
{
"metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
"plots/feature_density_line_chart": wandb_histogram,
# "plots/feature_density_line_chart": wandb_histogram,
"sparsity/below_1e-5": (feature_sparsity < 1e-5)
.sum()
.item(),
Expand Down

0 comments on commit 3ab717d

Please sign in to comment.