diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index a5be1ed8..0e0ae403 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -253,7 +253,7 @@ def encode( label = x if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: - x = x - self.decoder.bias + x = x - self.decoder.bias.to_local() if self.cfg.tp_size > 1 else x - self.decoder.bias x = x * self.compute_norm_factor(x, hook_point="in")