diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index c13eb9af..e5f02247 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -119,11 +119,22 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
         cache_dir=cfg.lm.cache_dir,
         local_files_only=cfg.lm.local_files_only,
     )
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (
+            cfg.lm.model_name
+            if cfg.lm.model_from_pretrained_path is None
+            else cfg.lm.model_from_pretrained_path
+        ),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
     model = HookedTransformer.from_pretrained(
         cfg.lm.model_name,
         device=cfg.lm.device,
         cache_dir=cfg.lm.cache_dir,
         hf_model=hf_model,
+        tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
     model.eval()
@@ -239,11 +250,22 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
         cache_dir=cfg.lm.cache_dir,
         local_files_only=cfg.lm.local_files_only,
     )
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (
+            cfg.lm.model_name
+            if cfg.lm.model_from_pretrained_path is None
+            else cfg.lm.model_from_pretrained_path
+        ),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
     model = HookedTransformer.from_pretrained(
         cfg.lm.model_name,
         device=cfg.lm.device,
         cache_dir=cfg.lm.cache_dir,
         hf_model=hf_model,
+        tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
     model.eval()