From 2320def2dcb55e718341153d4fdbf716b8754485 Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Tue, 3 Dec 2024 08:54:44 +0800 Subject: [PATCH] fix: generate activations --- TransformerLens | 2 +- src/lm_saes/activation/activation_dataset.py | 5 +- src/lm_saes/activation/activation_store.py | 19 ++- src/lm_saes/activation/token_source.py | 14 +- src/lm_saes/entrypoint.py | 7 + src/lm_saes/runner.py | 136 ++++++------------- 6 files changed, 86 insertions(+), 97 deletions(-) diff --git a/TransformerLens b/TransformerLens index 39348462..ebf1e060 160000 --- a/TransformerLens +++ b/TransformerLens @@ -1 +1 @@ -Subproject commit 39348462033b7c7dcd80d93cefac8f0b9bfc4382 +Subproject commit ebf1e060e017460e90ecf2e482d196cf936b5fc1 diff --git a/src/lm_saes/activation/activation_dataset.py b/src/lm_saes/activation/activation_dataset.py index 26389622..222093df 100644 --- a/src/lm_saes/activation/activation_dataset.py +++ b/src/lm_saes/activation/activation_dataset.py @@ -5,9 +5,7 @@ from tqdm.auto import tqdm from transformer_lens import HookedTransformer -from ..config import ( - ActivationGenerationConfig, -) +from ..config import ActivationGenerationConfig from ..utils.misc import is_master, print_once from .activation_store import ActivationStore from .token_source import TokenSource @@ -90,6 +88,7 @@ def make_activation_dataset(model: HookedTransformer, cfg: ActivationGenerationC pbar = tqdm( total=total_generating_tokens, desc=f"Activation dataset Rank {dist.get_rank()}" if dist.is_initialized() else "Activation dataset", + smoothing=0.001, ) while n_tokens < total_generating_tokens: diff --git a/src/lm_saes/activation/activation_store.py b/src/lm_saes/activation/activation_store.py index 4ee1aee9..c637ab12 100644 --- a/src/lm_saes/activation/activation_store.py +++ b/src/lm_saes/activation/activation_store.py @@ -6,6 +6,7 @@ import torch.distributed._functional_collectives as funcol import torch.utils.data from torch.distributed.device_mesh import init_device_mesh +from tqdm import tqdm from transformer_lens import HookedTransformer from ..config import ActivationStoreConfig @@ -26,7 +27,10 @@ def __init__(self, act_source: ActivationSource, cfg: ActivationStoreConfig): self.tp_size = cfg.tp_size self._store: Dict[str, torch.Tensor] = {} self._all_gather_buffer: Dict[str, torch.Tensor] = {} - self.device_mesh = init_device_mesh("cuda", (self.ddp_size, self.tp_size), mesh_dim_names=("ddp", "tp")) + if self.tp_size > 1 or self.ddp_size > 1: + self.device_mesh = init_device_mesh("cuda", (self.ddp_size, self.tp_size), mesh_dim_names=("ddp", "tp")) + else: + self.device_mesh = None def initialize(self): self.refill() @@ -41,6 +45,14 @@ def shuffle(self): self._store[k] = self._store[k][perm] def refill(self): + pbar = tqdm( + total=self.buffer_size, + desc="Refilling activation store", + smoothing=0, + leave=False, + initial=self.__len__(), + ) + n_seqs = 0 while self.__len__() < self.buffer_size: new_act = self.act_source.next() if new_act is None: @@ -53,6 +65,10 @@ def refill(self): self._store[k] = torch.cat([self._store[k], v], dim=0) # Check if all activations have the same size assert len(set(v.size(0) for v in self._store.values())) == 1 + n_seqs += 1 + pbar.update(next(iter(new_act.values())).size(0)) + pbar.set_postfix({"Sequences": n_seqs}) + pbar.close() def __len__(self): if len(self._store) == 0: @@ -75,6 +91,7 @@ def next(self, batch_size) -> Dict[str, torch.Tensor] | None: if dist.is_initialized(): # Wait for all processes to refill the store dist.barrier() if self.tp_size > 1: + assert self.device_mesh is not None, "Device mesh not initialized" for k, v in self._store.items(): if k not in self._all_gather_buffer: self._all_gather_buffer[k] = torch.empty(size=(0,), dtype=v.dtype, device=self.device) diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py index a018f6b4..35d7506d 100644 --- a/src/lm_saes/activation/token_source.py +++ b/src/lm_saes/activation/token_source.py @@ -43,7 +43,18 @@ def __init__( def fill_with_one_batch(self, batch: dict[str, Any], pack: bool, prepend_bos: bool) -> None: if self.is_dataset_tokenized: - tokens: torch.Tensor = batch["tokens"].to(self.device) + if isinstance(batch["input_ids"], torch.Tensor): + assert not batch["input_ids"].dtype.is_floating_point, "input_ids must be a tensor of integers" + tokens = batch["input_ids"].to(self.device) + else: + assert isinstance(batch["input_ids"], list), "input_ids must be a list or a tensor" + print("Batch size:", len(batch["input_ids"]), "Type:", type(batch["input_ids"])) + print("Sequence length:", len(batch["input_ids"][0]), "Type:", type(batch["input_ids"][0])) + # Check if all sequences in the batch have the same length + assert all( + len(seq) == len(batch["input_ids"][0]) for seq in batch["input_ids"] + ), "All sequences must have the same length" + tokens = torch.tensor(batch["input_ids"], dtype=torch.long, device=self.device) else: tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device) if pack: @@ -124,6 +135,7 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig): shard = dataset.shard(num_shards=dist.get_world_size(), index=shard_id, contiguous=True) else: shard = dataset + shard = shard.with_format("torch") dataloader = DataLoader( dataset=cast(Dataset[dict[str, Any]], shard), batch_size=cfg.store_batch_size, pin_memory=True diff --git a/src/lm_saes/entrypoint.py b/src/lm_saes/entrypoint.py index f0961a1b..f952a9af 100644 --- a/src/lm_saes/entrypoint.py +++ b/src/lm_saes/entrypoint.py @@ -9,6 +9,7 @@ class SupportedRunner(Enum): EVAL = "eval" ANALYZE = "analyze" PRUNE = "prune" + GENERATE_ACTIVATIONS = "gen-activations" def __str__(self): return self.value @@ -97,6 +98,12 @@ def entrypoint(): config = LanguageModelSAEPruningConfig.from_flattened(config) language_model_sae_prune_runner(config) + elif args.runner == SupportedRunner.GENERATE_ACTIVATIONS: + from lm_saes.config import ActivationGenerationConfig + from lm_saes.runner import activation_generation_runner + + config = ActivationGenerationConfig.from_flattened(config) + activation_generation_runner(config) else: raise ValueError(f"Unsupported runner: {args.runner}.") diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index 4221ad81..a6295431 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -11,7 +11,12 @@ parallelize_module, ) from transformer_lens import HookedTransformer -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + ChameleonForConditionalGeneration, + PreTrainedModel, +) from .activation.activation_dataset import make_activation_dataset from .activation.activation_source import CachedActivationSource @@ -21,6 +26,7 @@ from .config import ( ActivationGenerationConfig, FeaturesDecoderConfig, + LanguageModelConfig, LanguageModelCrossCoderTrainingConfig, LanguageModelSAEAnalysisConfig, LanguageModelSAEPruningConfig, @@ -36,36 +42,47 @@ from .utils.misc import is_master +def get_model(cfg: LanguageModelConfig): + if "chameleon" in cfg.model_name: + hf_model = ChameleonForConditionalGeneration.from_pretrained( + (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path), + cache_dir=cfg.cache_dir, + local_files_only=cfg.local_files_only, + torch_dtype=cfg.dtype, + ) + else: + hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path), + cache_dir=cfg.cache_dir, + local_files_only=cfg.local_files_only, + torch_dtype=cfg.dtype, + ) + hf_tokenizer = AutoTokenizer.from_pretrained( + (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path), + trust_remote_code=True, + use_fast=True, + add_bos_token=True, + ) + model = HookedTransformer.from_pretrained_no_processing( + cfg.model_name, + use_flash_attn=cfg.use_flash_attn, + device=cfg.device, + cache_dir=cfg.cache_dir, + hf_model=hf_model, + tokenizer=hf_tokenizer, + dtype=cfg.dtype, + ) + model.eval() + return model + + def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): if cfg.act_store.use_cached_activations: activation_source = CachedActivationSource(cfg.act_store) activation_store = ActivationStore(act_source=activation_source, cfg=cfg.act_store) model = None else: - hf_model = AutoModelForCausalLM.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - cache_dir=cfg.lm.cache_dir, - local_files_only=cfg.lm.local_files_only, - torch_dtype=cfg.lm.dtype, - ) - hf_tokenizer = AutoTokenizer.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - trust_remote_code=True, - use_fast=True, - add_bos_token=True, - ) - - model = HookedTransformer.from_pretrained_no_processing( - cfg.lm.model_name, - use_flash_attn=cfg.lm.use_flash_attn, - device=cfg.lm.device, - cache_dir=cfg.lm.cache_dir, - hf_model=hf_model, - tokenizer=hf_tokenizer, - dtype=cfg.lm.dtype, - ) - model.offload_params_after(cfg.act_store.hook_points[-1], torch.tensor([[0]], device=cfg.lm.device)) - model.eval() + model = get_model(cfg.lm) activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) if not cfg.finetuning and ( @@ -182,28 +199,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_path)) cfg.lm.save_lm_config(os.path.join(cfg.exp_result_path)) sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - hf_model = AutoModelForCausalLM.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - cache_dir=cfg.lm.cache_dir, - local_files_only=cfg.lm.local_files_only, - torch_dtype=cfg.lm.dtype, - ) - hf_tokenizer = AutoTokenizer.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - trust_remote_code=True, - use_fast=True, - add_bos_token=True, - ) - model = HookedTransformer.from_pretrained_no_processing( - cfg.lm.model_name, - use_flash_attn=cfg.lm.use_flash_attn, - device=cfg.lm.device, - cache_dir=cfg.lm.cache_dir, - hf_model=hf_model, - tokenizer=hf_tokenizer, - dtype=cfg.lm.dtype, - ) - model.eval() + model = get_model(cfg.lm) activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) if cfg.wandb.log_to_wandb and is_master(): wandb_config: dict = { @@ -243,29 +239,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - hf_model = AutoModelForCausalLM.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - cache_dir=cfg.lm.cache_dir, - local_files_only=cfg.lm.local_files_only, - ) - - hf_tokenizer = AutoTokenizer.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - trust_remote_code=True, - use_fast=True, - add_bos_token=True, - ) - model = HookedTransformer.from_pretrained_no_processing( - cfg.lm.model_name, - use_flash_attn=cfg.lm.use_flash_attn, - device=cfg.lm.device, - cache_dir=cfg.lm.cache_dir, - hf_model=hf_model, - tokenizer=hf_tokenizer, - dtype=cfg.lm.dtype, - ) - - model.eval() + model = get_model(cfg.lm) activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) if cfg.wandb.log_to_wandb and is_master(): @@ -301,27 +275,7 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): def activation_generation_runner(cfg: ActivationGenerationConfig): - hf_model = AutoModelForCausalLM.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - cache_dir=cfg.lm.cache_dir, - local_files_only=cfg.lm.local_files_only, - ) - hf_tokenizer = AutoTokenizer.from_pretrained( - (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path), - trust_remote_code=True, - use_fast=True, - add_bos_token=True, - ) - model = HookedTransformer.from_pretrained_no_processing( - cfg.lm.model_name, - use_flash_attn=cfg.lm.use_flash_attn, - device=cfg.lm.device, - cache_dir=cfg.lm.cache_dir, - hf_model=hf_model, - tokenizer=hf_tokenizer, - dtype=cfg.lm.dtype, - ) - model.eval() + model = get_model(cfg.lm) make_activation_dataset(model, cfg)