refactor: make all types work with basedpyright
dest1n1s committed Nov 18, 2024
1 parent 16ca67a commit 8c61d1d
Showing 20 changed files with 253 additions and 550 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
@@ -52,7 +52,7 @@ jobs:
- name: Install dependencies
run: pdm install
- name: Type check
run: pdm run mypy .
run: pdm run basedpyright
- name: Unit tests
run: pdm run pytest ./tests

150 changes: 74 additions & 76 deletions pdm.lock

Large diffs are not rendered by default.

13 changes: 5 additions & 8 deletions pyproject.toml
@@ -48,22 +48,15 @@ lm-saes = "lm_saes.entrypoint:entrypoint"
[dependency-groups]
dev = [
"-e file:///${PROJECT_ROOT}/TransformerLens#egg=transformer-lens",
"mypy>=1.13.0",
"pytest>=8.3.3",
"ipykernel>=6.29.5",
"nbformat>=5.10.4",
"kaleido==0.2.1",
"pre-commit>=4.0.1",
"ruff>=0.7.1",
"basedpyright>=1.21.0",
]

[tool.mypy]
check_untyped_defs=true
exclude=[".venv/", "examples", "TransformerLens", "tests", "exp"]
ignore_missing_imports=true
allow_redefinition=true
implicit_optional=true

[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
@@ -144,3 +137,7 @@ docstring-code-format = false
# enabled.
docstring-code-line-length = "dynamic"

[tool.pyright]
ignore = [".venv/", "examples", "TransformerLens", "tests", "exp"]
typeCheckingMode = "standard"
reportRedeclaration = false
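
For context (not part of the diff): typeCheckingMode = "standard" enables pyright's default rule set, and reportRedeclaration = false silences the diagnostic for a name declared more than once with differing types. A minimal, hypothetical sketch of the pattern that setting permits:

    # Hypothetical example; pyright would normally report the second
    # annotation as obscuring the first declaration of `path`.
    path: str | None = None
    path: str = "results/default"  # re-declared with a narrower type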
89 changes: 57 additions & 32 deletions server/app.py
@@ -42,7 +42,7 @@

def get_model(dictionary_name: str) -> HookedTransformer:
path = client.get_dictionary_path(dictionary_name, dictionary_series=dictionary_series)
if path == "":
if path is None:
path = f"{result_dir}/{dictionary_name}"
cfg = LanguageModelConfig.from_pretrained_sae(path)
if (cfg.model_name, cfg.model_from_pretrained_path) not in lm_cache:
@@ -71,10 +71,8 @@ def get_model(dictionary_name: str) -> HookedTransformer:


def get_sae(dictionary_name: str) -> SparseAutoEncoder:
path = (
client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)["path"]
or f"{result_dir}/{dictionary_name}"
)
dictionary = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)
path = dictionary["path"] if dictionary is not None else f"{result_dir}/{dictionary_name}"
if dictionary_name not in sae_cache:
sae = SparseAutoEncoder.from_pretrained(path)
sae.eval()
@@ -116,18 +114,17 @@ def list_dictionaries():
@app.get("/dictionaries/{dictionary_name}/features/{feature_index}")
def get_feature(dictionary_name: str, feature_index: str | int):
tokenizer = get_model(dictionary_name).tokenizer
if isinstance(feature_index, str):
if feature_index == "random":
feature = client.get_random_alive_feature(dictionary_name, dictionary_series=dictionary_series)
else:
try:
feature_index = int(feature_index)
except ValueError:
return Response(
content=f"Feature index {feature_index} is not a valid integer",
status_code=400,
)
if isinstance(feature_index, int):
if isinstance(feature_index, str) and feature_index != "random":
try:
feature_index = int(feature_index)
except ValueError:
return Response(
content=f"Feature index {feature_index} is not a valid integer",
status_code=400,
)
if feature_index == "random":
feature = client.get_random_alive_feature(dictionary_name, dictionary_series=dictionary_series)
else:
feature = client.get_feature(dictionary_name, feature_index, dictionary_series=dictionary_series)

if feature is None:
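
After this reordering, feature_index can only be an int or the literal string "random" once the first block finishes, which is what lets the later branches type-check cleanly. A minimal, hypothetical sketch of the same idea:

    def normalize_index(value: str | int) -> int | str:
        """Return an int index, or the "random" sentinel string unchanged."""
        if isinstance(value, str) and value != "random":
            value = int(value)  # ValueError propagates for malformed input
        return value

    normalize_index("17")      # -> 17
    normalize_index("random")  # -> "random"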
@@ -142,7 +139,8 @@ def get_feature(dictionary_name: str, feature_index: str | int):
{
"context": [
bytearray([byte_decoder[c] for c in t])
for t in tokenizer.convert_ids_to_tokens(analysis["contexts"][i])
# Method `convert_ids_to_tokens` should exist on GPT2Tokenizer and other BPE tokenizers.
for t in tokenizer.convert_ids_to_tokens(analysis["contexts"][i]) # type: ignore
],
"feature_acts": analysis["feature_acts"][i],
}
@@ -176,6 +174,8 @@ def get_feature(dictionary_name: str, feature_index: str | int):
+ ["#636EFA" for _ in range((len(logits_bin_edges) - 1) // 2)],
showlegend=False,
).to_plotly_json()
else:
logits_histogram = None
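
Binding logits_histogram in an explicit else branch gives the name a value on every code path, which avoids a possibly-unbound diagnostic from the type checker. A small sketch of the pattern, with hypothetical names:

    # Hypothetical sketch: without the else branch, `histogram` would be
    # possibly unbound whenever `edges` is None, which pyright reports.
    def build_histogram(edges: list[float] | None) -> dict[str, list[float]] | None:
        if edges is not None:
            histogram = {"edges": edges}
        else:
            histogram = None
        return histogram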

return Response(
content=msgpack.packb(
@@ -258,7 +258,9 @@ def feature_activation_custom_input(dictionary_name: str, feature_index: int, in
feature_acts = sae.encode(cache[sae.cfg.hook_point_in][0])
sample = {
"context": [
bytearray([byte_decoder[c] for c in t]) for t in model.tokenizer.convert_ids_to_tokens(input[0])
bytearray([byte_decoder[c] for c in t])
# Method `convert_ids_to_tokens` should exist on GPT2Tokenizer and other BPE tokenizers.
for t in model.tokenizer.convert_ids_to_tokens(input[0]) # type: ignore
],
"feature_acts": feature_acts[:, feature_index].tolist(),
}
@@ -274,6 +276,7 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):
return Response(content=f"Dictionary {dictionary_name} not found", status_code=404)

max_feature_acts = client.get_max_feature_acts(dictionary_name, dictionary_series=dictionary_series)
assert max_feature_acts is not None, "Max feature acts not found"

model = get_model(dictionary_name)

@@ -288,7 +291,9 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):
feature_acts = sae.encode(cache[sae.cfg.hook_point_in][0])
sample = {
"context": [
bytearray([byte_decoder[c] for c in t]) for t in model.tokenizer.convert_ids_to_tokens(input[0])
bytearray([byte_decoder[c] for c in t])
# Method `convert_ids_to_tokens` should exist on GPT2Tokenizer and other BPE tokenizers.
for t in model.tokenizer.convert_ids_to_tokens(input[0]) # type: ignore
],
"feature_acts_indices": [
feature_acts[i].nonzero(as_tuple=True)[0].tolist() for i in range(feature_acts.shape[0])
@@ -356,6 +361,10 @@ def model_generate(request: ModelGenerateRequest):
max_feature_acts = {
name: client.get_max_feature_acts(name, dictionary_series=dictionary_series) for _, name in saes
}
assert all(
max_feature_acts is not None for max_feature_acts in max_feature_acts.values()
), "Max feature acts not found"
max_feature_acts = cast(dict[str, dict[int, int]], max_feature_acts)
assert all(steering.sae in request.saes for steering in request.steerings), "Steering SAE not found"

def generate_steering_hook(steering: SteeringConfig):
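
The assert all(...) above is a runtime guard only: it does not change the inferred value type of the dict, so the cast is what records the "no None values" invariant for the type checker. A minimal, self-contained sketch with hypothetical data:

    from typing import cast

    acts: dict[str, dict[int, int] | None] = {"blocks.0": {0: 3}, "blocks.1": {1: 5}}
    assert all(v is not None for v in acts.values()), "Max feature acts not found"
    # The assert is a runtime check only; it does not narrow the dict's value
    # type, so the cast documents the invariant that every value is present.
    acts_checked = cast(dict[str, dict[int, int]], acts)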
@@ -389,11 +398,14 @@ def steering_hook(tensor: torch.Tensor, hook: HookPoint):
else torch.tensor([request.input_text], device=device)
)
if request.max_new_tokens > 0:
output = model.generate(
input,
max_new_tokens=request.max_new_tokens,
top_k=request.top_k,
top_p=request.top_p,
output = cast(
torch.Tensor,
model.generate(
input,
max_new_tokens=request.max_new_tokens,
top_k=request.top_k,
top_p=request.top_p,
),
)
input = output.clone()
name_filter = (
@@ -406,15 +418,18 @@ def steering_hook(tensor: torch.Tensor, hook: HookPoint):

result = {
"context": [
bytearray([byte_decoder[c] for c in t]) for t in model.tokenizer.convert_ids_to_tokens(input[0])
bytearray([byte_decoder[c] for c in t])
# Method `convert_ids_to_tokens` should exist on GPT2Tokenizer and other BPE tokenizers.
for t in model.tokenizer.convert_ids_to_tokens(input[0]) # type: ignore
],
"token_ids": input[0].tolist(),
"logits": {
"logits": [l.values.tolist() for l in logits_topk],
"tokens": [
[
bytearray([byte_decoder[c] for c in t])
for t in model.tokenizer.convert_ids_to_tokens(l.indices)
# Method `convert_ids_to_tokens` should exist on GPT2Tokenizer and other BPE tokenizers.
for t in model.tokenizer.convert_ids_to_tokens(l.indices) # type: ignore
]
for l in logits_topk
],
@@ -469,10 +484,15 @@ def model_trace(request: ModelTraceRequest):
dictionaries = client.list_dictionaries(dictionary_series=dictionary_series)
assert len(dictionaries) > 0, "No dictionaries found. Model name cannot be inferred."
model = get_model(dictionaries[0])
assert model.tokenizer is not None, "Tokenizer not found"
saes = [(get_sae(name), name) for name in request.saes]
max_feature_acts = {
name: client.get_max_feature_acts(name, dictionary_series=dictionary_series) for _, name in saes
}
assert all(
max_feature_acts is not None for max_feature_acts in max_feature_acts.values()
), "Max feature acts not found"
max_feature_acts = cast(dict[str, dict[int, int]], max_feature_acts)
assert all(steering.sae in request.saes for steering in request.steerings), "Steering SAE not found"
assert all(
tracing.sae in request.saes for tracing in request.tracings if isinstance(tracing, FeatureNode)
@@ -629,7 +649,9 @@ def steering_hook(tensor: torch.Tensor, hook: HookPoint):

result = {
"context": [
bytearray([byte_decoder[c] for c in t]) for t in model.tokenizer.convert_ids_to_tokens(input[0])
bytearray([byte_decoder[c] for c in t])
# Method `convert_ids_to_tokens` should exist on GPT2Tokenizer and other BPE tokenizers.
for t in model.tokenizer.convert_ids_to_tokens(input[0]) # type: ignore
],
"token_ids": input[0].tolist(),
"tracings": tracing_results,
@@ -645,10 +667,9 @@ def feature_interpretation(
custom_interpretation: str | None = None,
):
model = get_model(dictionary_name)
path = (
client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)["path"]
or f"{result_dir}/{dictionary_name}"
)
dictionary = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)
assert dictionary is not None, "Dictionary not found"
path = dictionary["path"]
if type == "custom":
interpretation: Any = {
"text": custom_interpretation,
@@ -669,6 +690,7 @@ def feature_interpretation(
}
)
feature = client.get_feature(dictionary_name, feature_index, dictionary_series=dictionary_series)
assert feature is not None, "Feature not found"
result = generate_description(model, feature["analysis"][0], cfg)
interpretation = {
"text": result["response"],
@@ -685,6 +707,7 @@ def feature_interpretation(
}
)
feature = client.get_feature(dictionary_name, feature_index, dictionary_series=dictionary_series)
assert feature is not None, "Feature not found"
interpretation = feature["interpretation"] if "interpretation" in feature else None
if interpretation is None:
return Response(content="Feature interpretation not found", status_code=404)
@@ -721,6 +744,8 @@ def feature_interpretation(
"detail": validation_result,
}
)
else:
return Response(content="Invalid interpretation type", status_code=400)

try:
client.update_feature(
1 change: 1 addition & 0 deletions src/lm_saes/activation/activation_dataset.py
@@ -50,6 +50,7 @@ def make_activation_dataset(model: HookedTransformer, cfg: ActivationGenerationC

while n_tokens_in_chunk < max_tokens_per_chunk:
tokens = token_source.next(cfg.dataset.store_batch_size)
assert tokens is not None, "Token source returned None"
_, cache = model.run_with_cache_until(tokens, names_filter=cfg.hook_points, until=cfg.hook_points[-1])
for hook_point in cfg.hook_points:
act = cache[hook_point]
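
token_source.next is typed as possibly returning None (presumably when the source is exhausted), so the added assert both fails fast at runtime and narrows the Optional away before the tensors are used. A minimal sketch of the narrowing, with hypothetical names:

    import torch

    def first_batch(batches: list[torch.Tensor | None]) -> torch.Tensor:
        batch = batches[0]
        # Fails fast at runtime and narrows `Tensor | None` to `Tensor` for
        # the type checker, mirroring the assert added above.
        assert batch is not None, "batch is unexpectedly None"
        return batch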
11 changes: 8 additions & 3 deletions src/lm_saes/activation/activation_source.py
@@ -7,7 +7,10 @@
from einops import rearrange, repeat
from transformer_lens import HookedTransformer

from lm_saes.activation.activation_dataset import list_activation_chunks, load_activation_chunk
from lm_saes.activation.activation_dataset import (
list_activation_chunks,
load_activation_chunk,
)
from lm_saes.activation.token_source import TokenSource
from lm_saes.config import ActivationStoreConfig

@@ -40,6 +43,8 @@ class TokenActivationSource(ActivationSource):
def __init__(self, model: HookedTransformer, cfg: ActivationStoreConfig):
self.token_source = TokenSource.from_config(model=model, cfg=cfg.dataset)
self.model = model
assert model.tokenizer is not None, "Tokenizer is not set"
self.tokenizer = model.tokenizer
self.cfg = cfg

def next(self) -> Dict[str, torch.Tensor] | None:
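
Asserting once in __init__ and storing the narrowed tokenizer means later methods can use self.tokenizer without repeating the Optional check on model.tokenizer. A small, hypothetical sketch of the pattern:

    from transformer_lens import HookedTransformer

    class TokenizerHolder:
        def __init__(self, model: HookedTransformer) -> None:
            assert model.tokenizer is not None, "Tokenizer is not set"
            # The attribute type is inferred from the narrowed expression,
            # so it is non-Optional for every later use.
            self.tokenizer = model.tokenizer

        def pad_id(self) -> int | None:
            return self.tokenizer.pad_token_id  # no per-call None check needed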
@@ -53,9 +58,9 @@ def next(self) -> Dict[str, torch.Tensor] | None:
)

filter_mask = torch.logical_and(
tokens.ne(self.model.tokenizer.eos_token_id), tokens.ne(self.model.tokenizer.pad_token_id)
tokens != self.tokenizer.eos_token_id, tokens != self.tokenizer.pad_token_id
)
filter_mask = torch.logical_and(filter_mask, tokens.ne(self.model.tokenizer.bos_token_id))
filter_mask = torch.logical_and(filter_mask, tokens != self.tokenizer.bos_token_id)

filter_mask = rearrange(filter_mask, "b l -> (b l)")

28 changes: 18 additions & 10 deletions src/lm_saes/activation/token_source.py
@@ -1,9 +1,11 @@
import random
from typing import Any, cast

import datasets
import torch
import torch.distributed as dist
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from transformer_lens import HookedTransformer

from lm_saes.config import TextDatasetConfig
@@ -22,6 +24,8 @@ def __init__(
):
self.dataloader = dataloader
self.model = model
assert model.tokenizer is not None, "Tokenizer is not set"
self.tokenizer = model.tokenizer
self.is_dataset_tokenized = is_dataset_tokenized
self.concat_tokens = concat_tokens
self.seq_len = seq_len
@@ -31,25 +35,23 @@ def __init__(

self.token_buffer = torch.empty((0, seq_len), dtype=torch.long, device=self.device)

self.bos_token_id_tensor = torch.tensor(
[self.model.tokenizer.bos_token_id], dtype=torch.long, device=self.device
)
self.bos_token_id_tensor = torch.tensor([self.tokenizer.bos_token_id], dtype=torch.long, device=self.device)
self.resid = torch.tensor([], dtype=torch.long, device=self.device)

self.sample_probs = sample_probs
self.prepend_bos = prepend_bos

def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None:
def fill_with_one_batch(self, batch: dict[str, Any], pack: bool, prepend_bos: bool) -> None:
if self.is_dataset_tokenized:
tokens: torch.Tensor = batch["tokens"].to(self.device)
else:
tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device)
if pack:
while tokens.size(0) > 0:
cur_tokens = tokens[0]
cur_tokens = cur_tokens[cur_tokens != self.model.tokenizer.bos_token_id]
cur_tokens = cur_tokens[cur_tokens != self.model.tokenizer.eos_token_id]
cur_tokens = cur_tokens[cur_tokens != self.model.tokenizer.pad_token_id]
cur_tokens = cur_tokens[cur_tokens != self.tokenizer.bos_token_id]
cur_tokens = cur_tokens[cur_tokens != self.tokenizer.eos_token_id]
cur_tokens = cur_tokens[cur_tokens != self.tokenizer.pad_token_id]

self.resid = torch.cat([self.resid, self.bos_token_id_tensor.clone(), cur_tokens], dim=0)
while self.resid.size(0) >= self.seq_len:
@@ -65,12 +67,15 @@ def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None:

if tokens.size(1) < self.seq_len:
pad_len = self.seq_len - tokens.size(1)
pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is None:
pad_token_id = 0 # Default to 0 if pad token not set
tokens = torch.cat(
[
tokens,
torch.full(
(tokens.size(0), pad_len),
self.model.tokenizer.pad_token_id,
pad_token_id,
dtype=torch.long,
device=self.device,
),
@@ -113,13 +118,16 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig):
dataset = load_dataset(dataset_path, split="train", cache_dir=cfg.cache_dir, keep_in_memory=True)
else:
dataset = load_from_disk(dataset_path, keep_in_memory=True)
dataset = cast(datasets.Dataset, dataset)
if dist.is_initialized():
shard_id = dist.get_rank()
shard = dataset.shard(num_shards=dist.get_world_size(), index=shard_id, contiguous=True)
else:
shard = dataset

dataloader = DataLoader(shard, batch_size=cfg.store_batch_size, pin_memory=True)
dataloader = DataLoader(
dataset=cast(Dataset[dict[str, Any]], shard), batch_size=cfg.store_batch_size, pin_memory=True
)
return dataloader
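
load_dataset and load_from_disk are annotated to return a union of dataset types (e.g. Dataset | DatasetDict), so the cast pins the concrete datasets.Dataset this pipeline expects before calling .shard and building a DataLoader. A minimal sketch of the same idea, with a hypothetical path:

    import datasets
    from typing import cast
    from datasets import load_from_disk

    # "data/my_corpus" is a hypothetical path; the cast pins the union
    # returned by load_from_disk to the single-split Dataset expected here.
    ds = cast(datasets.Dataset, load_from_disk("data/my_corpus"))
    shard = ds.shard(num_shards=4, index=0, contiguous=True)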

@staticmethod
Expand Down
