From b067b07d370c75dc426feac7ff7cda8c244f8836 Mon Sep 17 00:00:00 2001
From: Frankstein <20307140057@fudan.edu.cn>
Date: Wed, 24 Jul 2024 23:29:26 +0800
Subject: [PATCH 1/2] fix: typo error

---
 src/lm_saes/sae.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 416ae46d..b86569b7 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -515,7 +515,7 @@ def transform_to_unit_decoder_norm(self):
         decoder_norm = self.decoder_norm()  # (d_sae,)
 
         self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
-        self.decoder.weight.data = self.decoder.weight.data.T / decoder_norm
+        self.decoder.weight.data = self.decoder.weight.data / decoder_norm
 
         self.encoder.bias.data = self.encoder.bias.data * decoder_norm

From 173f43489324db19bbe87d4b817fc35795a8f728 Mon Sep 17 00:00:00 2001
From: Frankstein <20307140057@fudan.edu.cn>
Date: Thu, 1 Aug 2024 00:50:12 +0800
Subject: [PATCH 2/2] fix(sae): transform decoder_norm and encoder_norm to
 dtensor under tensor parallel settings

---
 src/lm_saes/sae.py | 41 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 34 insertions(+), 7 deletions(-)

diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 278211e0..6c5abae0 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -1,3 +1,4 @@
+from builtins import print
 from importlib.metadata import version
 import os
 from typing import Dict, Literal, Union, overload, List
@@ -117,7 +118,7 @@ def initialize_parameters(self):
         if self.cfg.init_encoder_with_decoder_transpose:
             self.encoder.weight.data = self.decoder.weight.data.T.clone().contiguous()
         else:
-            self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm)
+            self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm, during_init=True)
 
     def train_base_parameters(self):
         """Set the base parameters to be trained."""
@@ -481,6 +482,13 @@ def set_decoder_norm_to_fixed_norm(
         decoder_norm = self.decoder_norm(keepdim=True, during_init=during_init)
         if force_exact is None:
             force_exact = self.cfg.decoder_exactly_fixed_norm
+
+
+        if self.cfg.tp_size > 1 and not during_init:
+            decoder_norm = distribute_tensor(
+                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+
         if force_exact:
             self.decoder.weight.data = self.decoder.weight.data * value / decoder_norm
         else:
@@ -490,7 +498,7 @@
-    def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0):
+    def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0, during_init: bool = False):
         if self.cfg.use_glu_encoder:
             raise NotImplementedError("GLU encoder not supported")
         if value is None:
                 f"Encoder norm is not set to a fixed value, using random initialization."
             )
             return
-        encoder_norm = self.encoder_norm(keepdim=True)
+        encoder_norm = self.encoder_norm(keepdim=True, during_init=during_init)
+        if self.cfg.tp_size > 1 and not during_init:
+            encoder_norm = distribute_tensor(
+                encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
         self.encoder.weight.data = self.encoder.weight.data * value / encoder_norm
 
     @torch.no_grad()
@@ -515,10 +527,25 @@ def transform_to_unit_decoder_norm(self):
         raise NotImplementedError("GLU encoder not supported")
 
         decoder_norm = self.decoder_norm()  # (d_sae,)
-        self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
-        self.decoder.weight.data = self.decoder.weight.data / decoder_norm
-
-        self.encoder.bias.data = self.encoder.bias.data * decoder_norm
+        if self.cfg.tp_size > 1:
+            decoder_norm_en = distribute_tensor(
+                decoder_norm[:, None], device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+            decoder_norm_de = distribute_tensor(
+                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+            dencoder_norm_bias = distribute_tensor(
+                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+            )
+        else:
+            decoder_norm_en = decoder_norm[:, None]
+            decoder_norm_de = decoder_norm
+            dencoder_norm_bias = decoder_norm
+
+        self.encoder.weight.data = self.encoder.weight.data * decoder_norm_en
+        self.decoder.weight.data = self.decoder.weight.data / decoder_norm_de
+
+        self.encoder.bias.data = self.encoder.bias.data * dencoder_norm_bias
 
     @torch.no_grad()
     def remove_gradient_parallel_to_decoder_directions(self):
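
Note on [PATCH 1/2]: the corrected line divides decoder.weight, which for an nn.Linear(d_sae, d_model)
decoder has shape (d_model, d_sae), by the (d_sae,) per-feature norms, so the division broadcasts over
the weight's columns; the transposed form would not broadcast against that vector when d_model differs
from d_sae. The sketch below is a standalone toy check, not the repository's SAE class, assuming the
nn.Linear layout implied by the diff; it verifies that the transform leaves the reconstruction unchanged
while making every decoder column unit-norm:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    d_model, d_sae = 16, 64

    encoder = nn.Linear(d_model, d_sae)              # weight: (d_sae, d_model)
    decoder = nn.Linear(d_sae, d_model, bias=False)  # weight: (d_model, d_sae)

    def reconstruct(x):
        return decoder(torch.relu(encoder(x)))

    x = torch.randn(8, d_model)
    before = reconstruct(x)

    with torch.no_grad():
        decoder_norm = decoder.weight.norm(p=2, dim=0)   # (d_sae,), one norm per feature
        encoder.weight.data *= decoder_norm[:, None]     # rescale encoder rows
        decoder.weight.data /= decoder_norm              # broadcasts over columns, no transpose needed
        encoder.bias.data *= decoder_norm                # rescale encoder bias

    print(torch.allclose(before, reconstruct(x), atol=1e-5))              # True: outputs unchanged
    print(torch.allclose(decoder.weight.norm(dim=0), torch.ones(d_sae)))  # True: unit decoder norm

The rescaling commutes with the ReLU because the per-feature norms are positive, which is why the
reconstruction is preserved exactly.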
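Note on [PATCH 2/2]: with cfg.tp_size > 1 the norms must be wrapped as replicated DTensors,
presumably because the SAE weights themselves are DTensors sharded across the "tp" mesh dimension
and plain torch.Tensor operands cannot be mixed into elementwise expressions with them; that is what
the added distribute_tensor(..., placements=[Replicate()]) calls do. Below is a rough standalone
sketch of the same pattern, with toy shapes and names that are not from the repository; run it with
torchrun --nproc_per_node=2. The imports are assumed to live at torch.distributed._tensor, their
location in PyTorch releases contemporary with this patch (newer releases expose them as
torch.distributed.tensor):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed._tensor import distribute_tensor, Replicate, Shard

    torch.manual_seed(0)                                          # same full tensors on every rank
    mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("tp",))  # 1-D "tp" mesh over two ranks

    # a decoder-like weight sharded along its feature dimension, as under tensor parallel
    full_weight = torch.randn(16, 64)
    weight = distribute_tensor(full_weight, device_mesh=mesh, placements=[Shard(1)])

    # per-feature norms computed as an ordinary local tensor
    norm = torch.linalg.vector_norm(full_weight, dim=0)           # shape (64,)

    # wrap the norms as a replicated DTensor so they can meet the sharded weight,
    # mirroring the distribute_tensor(..., placements=[Replicate()]) calls in the patch
    norm = distribute_tensor(norm, device_mesh=mesh, placements=[Replicate()])

    weight = weight / norm                                        # DTensor arithmetic, stays Shard(1)
    print(weight.placements)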