Skip to content

Commit

Permalink
fix(server): steering
Browse files Browse the repository at this point in the history
  • Loading branch information
dest1n1s committed Nov 18, 2024
1 parent fca4389 commit f0058a5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 72 deletions.
1 change: 0 additions & 1 deletion server/.env.example
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
MONGO_URI= # Must fill in
RESULT_DIR= # Must fill in
DICTIONARY_SERIES= # Must fill in
DICTIONARY_CKPT_NAME=final.pt

Expand Down
49 changes: 41 additions & 8 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ def model_generate(request: ModelGenerateRequest):
assert all(steering.sae in request.saes for steering in request.steerings), "Steering SAE not found"

def generate_steering_hook(steering: SteeringConfig):
def steering_hook(tensor: torch.Tensor, hook: HookPoint):
feature_acts = None

def steer(tensor: torch.Tensor):
assert len(tensor.shape) == 3
tensor = tensor.clone()
if steering.steering_type == "times":
Expand All @@ -384,14 +386,28 @@ def steering_hook(tensor: torch.Tensor, hook: HookPoint):
tensor[:, :, steering.feature_index] = steering.steering_value
return tensor

def save_feature_acts_hook(tensor: torch.Tensor, hook: HookPoint):
nonlocal feature_acts
feature_acts = tensor
return steer(tensor)

def steering_hook(tensor: torch.Tensor, hook: HookPoint):
assert feature_acts is not None, "Feature acts should be saved before steering"
difference = (steer(feature_acts) - feature_acts) @ sae.decoder.weight.T
tensor += difference.detach()
return tensor

sae = get_sae(steering.sae)
return f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", steering_hook
return [
(f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", save_feature_acts_hook),
(f"{sae.cfg.hook_point_out}", steering_hook),
]

steerings_hooks = [generate_steering_hook(steering) for steering in request.steerings]
steering_hooks = sum([generate_steering_hook(steering) for steering in request.steerings], [])

with torch.no_grad():
with apply_sae(model, [sae for sae, _ in saes]):
with model.hooks(steerings_hooks):
with model.hooks(steering_hooks):
input = (
model.to_tokens(request.input_text, prepend_bos=False)
if isinstance(request.input_text, str)
Expand Down Expand Up @@ -499,7 +515,10 @@ def model_trace(request: ModelTraceRequest):
), "Tracing SAE not found"

def generate_steering_hook(steering: SteeringConfig):
def steering_hook(tensor: torch.Tensor, hook: HookPoint):
feature_acts = None
sae = get_sae(steering.sae)

def steer(tensor: torch.Tensor):
assert len(tensor.shape) == 3
tensor = tensor.clone()
if steering.steering_type == "times":
Expand All @@ -515,17 +534,31 @@ def steering_hook(tensor: torch.Tensor, hook: HookPoint):
tensor[:, :, steering.feature_index] = steering.steering_value
return tensor

def save_feature_acts_hook(tensor: torch.Tensor, hook: HookPoint):
nonlocal feature_acts
feature_acts = tensor
return steer(tensor)

def steering_hook(tensor: torch.Tensor, hook: HookPoint):
assert feature_acts is not None, "Feature acts should be saved before steering"
difference = (steer(feature_acts) - feature_acts) @ sae.decoder.weight.T
tensor += difference.detach()
return tensor

sae = get_sae(steering.sae)
return f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", steering_hook
return [
(f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", save_feature_acts_hook),
(f"{sae.cfg.hook_point_out}", steering_hook),
]

steerings_hooks = [generate_steering_hook(steering) for steering in request.steerings]
steering_hooks = sum([generate_steering_hook(steering) for steering in request.steerings], [])

candidates = [f"{sae.cfg.hook_point_out}.sae.hook_feature_acts" for sae, _ in saes]
if request.detach_at_attn_scores:
candidates += [f"blocks.{i}.attn.hook_attn_scores" for i in range(model.cfg.n_layers)]

with apply_sae(model, [sae for sae, _ in saes]):
with model.hooks(steerings_hooks):
with model.hooks(steering_hooks):
with detach_at(model, candidates):
input = (
model.to_tokens(request.input_text, prepend_bos=False)
Expand Down
63 changes: 0 additions & 63 deletions tests/intergration/test_attributor.py

This file was deleted.

0 comments on commit f0058a5

Please sign in to comment.