Skip to content

Commit

Permalink
misc: minor bug fix & support loading configs with path in the MongoDB
Browse files Browse the repository at this point in the history
  • Loading branch information
dest1n1s committed Aug 12, 2024
1 parent acfac55 commit 07a9c0e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
15 changes: 9 additions & 6 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@


def get_model(dictionary_name: str) -> HookedTransformer:
cfg = LanguageModelConfig.from_pretrained_sae(f"{result_dir}/{dictionary_name}")
path = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)['path'] or f"{result_dir}/{dictionary_name}"
cfg = LanguageModelConfig.from_pretrained_sae(path)
if (cfg.model_name, cfg.model_from_pretrained_path) not in lm_cache:
hf_model = AutoModelForCausalLM.from_pretrained(
(
Expand Down Expand Up @@ -81,8 +82,9 @@ def get_model(dictionary_name: str) -> HookedTransformer:


def get_sae(dictionary_name: str) -> SparseAutoEncoder:
path = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)['path'] or f"{result_dir}/{dictionary_name}"
if dictionary_name not in sae_cache:
sae = SparseAutoEncoder.from_pretrained(f"{result_dir}/{dictionary_name}", device=device)
sae = SparseAutoEncoder.from_pretrained(path, device=device)
sae.eval()
sae_cache[dictionary_name] = sae
return sae_cache[dictionary_name]
Expand Down Expand Up @@ -398,6 +400,7 @@ def feature_interpretation(
custom_interpretation: str | None = None,
):
model = get_model(dictionary_name)
path = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)['path'] or f"{result_dir}/{dictionary_name}"
if type == "custom":
interpretation = {
"text": custom_interpretation,
Expand All @@ -411,8 +414,8 @@ def feature_interpretation(
elif type == "auto":
cfg = AutoInterpConfig(
**{
**SAEConfig.from_pretrained(f"{result_dir}/{dictionary_name}").to_dict(),
**LanguageModelConfig.from_pretrained_sae(f"{result_dir}/{dictionary_name}").to_dict(),
"sae": SAEConfig.from_pretrained(path).to_dict(),
"lm": LanguageModelConfig.from_pretrained_sae(path).to_dict(),
"openai_api_key": os.environ.get("OPENAI_API_KEY"),
"openai_base_url": os.environ.get("OPENAI_BASE_URL"),
}
Expand All @@ -427,8 +430,8 @@ def feature_interpretation(
elif type == "validate":
cfg = AutoInterpConfig(
**{
**SAEConfig.from_pretrained(f"{result_dir}/{dictionary_name}").to_dict(),
**LanguageModelConfig.from_pretrained_sae(f"{result_dir}/{dictionary_name}").to_dict(),
"sae": SAEConfig.from_pretrained(path).to_dict(),
"lm": LanguageModelConfig.from_pretrained_sae(path).to_dict(),
"openai_api_key": os.environ.get("OPENAI_API_KEY"),
"openai_base_url": os.environ.get("OPENAI_BASE_URL"),
}
Expand Down
11 changes: 11 additions & 0 deletions src/lm_saes/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ def update_feature(self, dictionary_name: str, feature_index: int, feature_data:
def list_dictionaries(self, dictionary_series: str | None = None):
return [d['name'] for d in self.dictionary_collection.find({'series': dictionary_series} if dictionary_series is not None else {})]

def get_dictionary(self, dictionary_name: str, dictionary_series: str | None = None):
dictionary = self.dictionary_collection.find_one({'name': dictionary_name, 'series': dictionary_series})
if dictionary is None:
return None
return {
'name': dictionary['name'],
'n_features': dictionary['n_features'],
'series': dictionary['series'],
'path': dictionary['path'] if 'path' in dictionary else None
}

def get_feature(self, dictionary_name: str, feature_index: int, dictionary_series: str | None = None):
dictionary = self.dictionary_collection.find_one({'name': dictionary_name, 'series': dictionary_series})
if dictionary is None:
Expand Down
5 changes: 3 additions & 2 deletions src/lm_saes/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def entrypoint():
else:
raise ValueError(f'Unsupported runner: {args.runner}.')

if dist.is_initialized():
dist.destroy_process_group()
if tp_size > 1 or ddp_size > 1:
if dist.is_initialized():
dist.destroy_process_group()

0 comments on commit 07a9c0e

Please sign in to comment.