Commit

Merge pull request #40 from OpenMOSS/zfhe
Zfhe
Hzfinfdu authored Jul 30, 2024
2 parents 98e74ba + 0d333e1 commit 99e42d2
Showing 4 changed files with 40 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/lm_saes/activation/token_source.py
@@ -116,13 +116,13 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig):
         if dist.is_initialized():
             shard_id = dist.get_rank()
             shard = dataset.shard(
-                num_shards=dist.get_world_size(), index=shard_id
+                num_shards=dist.get_world_size(), index=shard_id, contiguous=True
             )
         else:
             shard = dataset
 
 
-        dataloader = DataLoader(shard, batch_size=cfg.store_batch_size)
+        dataloader = DataLoader(shard, batch_size=cfg.store_batch_size, num_workers=4, prefetch_factor=4, pin_memory=True)
         return dataloader
 
     @staticmethod
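
The data-loading change does two things: each distributed rank now gets a contiguous block of the dataset instead of a strided slice, and the DataLoader overlaps CPU-side batch preparation with the consumer via worker processes, prefetching, and pinned memory. A minimal standalone sketch of the same pattern, assuming the Hugging Face `datasets` and PyTorch APIs used above; the toy corpus, shard count, and batch size are illustrative, not from the repository:

    from datasets import Dataset
    from torch.utils.data import DataLoader

    if __name__ == "__main__":  # guard needed because num_workers > 0 spawns worker processes
        # Toy corpus standing in for the configured text dataset (illustrative only).
        ds = Dataset.from_dict({"text": [f"doc {i}" for i in range(8)]})

        # contiguous=True gives rank 0 rows 0-3 and rank 1 rows 4-7 (consecutive
        # blocks) rather than the default strided split (0,2,4,6 vs 1,3,5,7), so
        # each rank reads a sequential slice of the underlying Arrow file.
        shard = ds.shard(num_shards=2, index=0, contiguous=True)

        # num_workers and prefetch_factor keep batches queued ahead of the consumer;
        # pin_memory speeds up later host-to-device copies of tensor fields.
        loader = DataLoader(shard, batch_size=2, num_workers=4,
                            prefetch_factor=4, pin_memory=True)

        for batch in loader:
            print(batch["text"])
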
2 changes: 1 addition & 1 deletion src/lm_saes/config.py
@@ -324,7 +324,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
     lr_warm_up_steps: int | float = 0.1
     lr_cool_down_steps: int | float = 0.1
     train_batch_size: int = 4096
-    clip_grad_value: float = 0.0
+    clip_grad_norm: float = 0.0
     remove_gradient_parallel_to_decoder_directions: bool = False
 
     finetuning: bool = False
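
The only change here renames the clipping threshold so it matches the norm-based clipping introduced in sae_training.py below; judging by the `> 0` check there, the default of 0.0 still means "no clipping". A toy sketch of the field, using a plain dataclass as a stand-in for the real config class (which this diff does not show):

    from dataclasses import dataclass

    @dataclass
    class ClipSettings:
        # Mirrors the renamed field: 0.0 disables clipping; a positive value is the
        # maximum global gradient norm passed to torch.nn.utils.clip_grad_norm_.
        clip_grad_norm: float = 0.0

    print(ClipSettings(clip_grad_norm=1.0))
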
44 changes: 27 additions & 17 deletions src/lm_saes/runner.py
@@ -32,6 +32,19 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from lm_saes.utils.misc import is_master
 
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    loss_parallel,
+)
+from torch.distributed._tensor import (
+    DTensor,
+    Shard,
+    Replicate,
+    distribute_module,
+    distribute_tensor,
+)
+
 
 def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     if is_master():
@@ -77,24 +90,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
 
-    if (
-        cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
-        or cfg.sae.init_decoder_norm is None
-    ):
-        assert not cfg.finetuning
-        sae = SparseAutoEncoder.from_initialization_searching(
-            activation_store=activation_store,
-            cfg=cfg,
-        )
-    else:
-        sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
-    if cfg.finetuning:
-        # Fine-tune SAE with frozen encoder weights and bias
-        sae.train_finetune_for_suppression_parameters()
-
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
@@ -304,6 +300,20 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
 def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
 
+    if cfg.sae.tp_size > 1:
+        plan = {
+            "encoder": ColwiseParallel(output_layouts=Replicate()),
+        }
+        if cfg.sae.use_glu_encoder:
+            plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
+        sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
+        sae.parallelize_plan = plan
+
+    sae.decoder.weight = None  # type: ignore[assignment]
+    torch.cuda.empty_cache()
+
+
+
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
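
For feature sampling, the runner now optionally shards the SAE encoder across a tensor-parallel device mesh and then drops the decoder weights (which this analysis path does not appear to need) to free GPU memory. Below is a minimal sketch of the same `parallelize_module` pattern, assuming two GPUs launched with `torchrun --nproc_per_node=2`; the `ToySAE` class, sizes, and mesh construction are illustrative, while the real code reuses the SAE's own device mesh:

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed._tensor import Replicate
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


    class ToySAE(nn.Module):
        """Stand-in for SparseAutoEncoder: a wide encoder and a decoder back down."""

        def __init__(self, d_model: int = 64, d_sae: int = 512):
            super().__init__()
            self.encoder = nn.Linear(d_model, d_sae)
            self.decoder = nn.Linear(d_sae, d_model, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.decoder(torch.relu(self.encoder(x)))


    if __name__ == "__main__":
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)

        # One-dimensional "tp" mesh covering all ranks.
        mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))
        sae = ToySAE().cuda()

        # Column-shard the encoder weight across ranks; Replicate() as the output
        # layout all-gathers the features so downstream (non-parallel) code sees a
        # full activation tensor, mirroring the plan in the diff above.
        plan = {"encoder": ColwiseParallel(output_layouts=Replicate())}
        sae = parallelize_module(sae, device_mesh=mesh, parallelize_plan=plan)

        out = sae(torch.randn(8, 64, device="cuda"))
        if rank == 0:
            print(out.shape)  # torch.Size([8, 64])

        dist.destroy_process_group()

Setting `sae.decoder.weight = None` and calling `torch.cuda.empty_cache()` afterwards, as the diff does, releases the decoder parameters once only encoder activations are being collected.
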
18 changes: 10 additions & 8 deletions src/lm_saes/sae_training.py
@@ -145,8 +145,9 @@ def train_sae(
         if cfg.finetuning:
             loss = loss_data["l_rec"].mean()
         loss.backward()
-        if cfg.clip_grad_value > 0:
-            torch.nn.utils.clip_grad_value_(sae.parameters(), cfg.clip_grad_value)
+        grad_norm = torch.tensor([0.0], device=cfg.sae.device)
+        if cfg.clip_grad_norm > 0:
+            grad_norm = torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
         if cfg.remove_gradient_parallel_to_decoder_directions:
             sae.remove_gradient_parallel_to_decoder_directions()
         optimizer.step()
@@ -171,13 +172,13 @@ def train_sae(
         if cfg.wandb.log_to_wandb and (is_master()):
             feature_sparsity = act_freq_scores / n_frac_active_tokens
             log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
-            wandb_histogram = wandb.Histogram(
-                log_feature_sparsity.detach().cpu().float().numpy()
-            )
+            # wandb_histogram = wandb.Histogram(
+            #     log_feature_sparsity.detach().cpu().float().numpy()
+            # )
             wandb.log(
                 {
                     "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
-                    "plots/feature_density_line_chart": wandb_histogram,
+                    # "plots/feature_density_line_chart": wandb_histogram,
                     "sparsity/below_1e-5": (feature_sparsity < 1e-5)
                     .sum()
                     .item(),
@@ -285,8 +286,9 @@ def train_sae(
                     # norm
                     "metrics/decoder_norm": decoder_norm.item(),
                     "metrics/encoder_norm": encoder_norm.item(),
-                    "metrics/decoder_bias_mean": sae.decoder.bias.mean().item() if sae.cfg.use_decoder_bias else 0,
-                    "metrics/enocder_bias_mean": sae.encoder.bias.mean().item(),
+                    "metrics/decoder_bias_norm": sae.decoder.bias.norm().item() if sae.cfg.use_decoder_bias else 0,
+                    "metrics/encoder_bias_norm": sae.encoder.bias.norm().item(),
+                    "metrics/gradients_norm": grad_norm.item(),
                     # sparsity
                     "sparsity/l1_coefficient": sae.current_l1_coefficient,
                     "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(),
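
The training-loop change swaps per-element value clipping for global norm clipping and logs the pre-clip norm alongside the renamed bias metrics (the per-feature density histogram is commented out of the wandb call rather than removed). A minimal standalone sketch of the clipping pattern, using a toy linear model in place of the SAE:

    import torch

    # Toy model and a single optimizer step with norm-based gradient clipping.
    model = torch.nn.Linear(8, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss = model(torch.randn(32, 8)).pow(2).mean()
    loss.backward()

    # clip_grad_norm_ rescales all gradients so their combined L2 norm is at most
    # max_norm and returns the norm measured *before* clipping, which is what the
    # diff logs as "metrics/gradients_norm". clip_grad_value_ clamps each element
    # independently and returns nothing, so there was nothing comparable to log.
    clip_grad_norm = 1.0  # 0.0 would mean "no clipping", matching the config default
    grad_norm = torch.tensor([0.0])
    if clip_grad_norm > 0:
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

    optimizer.step()
    print(f"pre-clip gradient norm: {grad_norm.item():.3f}")
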
