
Merge pull request #19 from OpenMOSS/opt-memory-speed
Resolves #10
dest1n1s authored Jun 8, 2024
2 parents d3865fc + f7f23e6 commit 05a5973
Showing 11 changed files with 238 additions and 27 deletions.
35 changes: 35 additions & 0 deletions TransformerLens/tests/acceptance/test_offloading.py
@@ -0,0 +1,35 @@
from transformer_lens import HookedTransformer
import torch

MODEL = "solu-2l"

def time_diff(func):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    func()

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event)

@torch.no_grad()
def test_offload_params_after():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")
    allocated_before = torch.cuda.memory_allocated(0)
    model.offload_params_after("blocks.0.hook_resid_post", torch.tensor([[0]], device="cuda"))
    allocated_after = torch.cuda.memory_allocated(0)
    assert allocated_after < allocated_before * 0.55

@torch.no_grad()
def test_run_with_cache_until():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")
    def forward():
        model.run_with_cache("Hello, world!", names_filter=["blocks.0.hook_resid_post"])
    forward_time = time_diff(forward)
    def forward_until():
        model.run_with_cache_until("Hello, world!", names_filter=["blocks.0.hook_resid_post"])
    forward_fake_time = time_diff(forward_until)
    assert forward_fake_time < forward_time * 0.7

52 changes: 52 additions & 0 deletions TransformerLens/tests/integration/test_offloading.py
@@ -0,0 +1,52 @@
from transformer_lens.hook_points import HookedRootModule, HookPoint
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.subblock1 = nn.Linear(10, 10)
        self.subblock2 = nn.Linear(10, 10)
        self.activation = nn.ReLU()
        self.hook_mid = HookPoint()

    def forward(self, x):
        return self.subblock2(self.hook_mid(self.activation(self.subblock1(x))))

class TestModule(HookedRootModule):
    __test__ = False
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.ModuleList([Block() for _ in range(3)])
        self.embed = nn.Linear(1, 10)
        self.unembed = nn.Linear(10, 1)
        self.setup()

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)

def test_run_with_cache_until():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    out, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])

    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], out)

def test_offload_params_after():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]))

    model.offload_params_after("blocks.1.hook_mid", torch.tensor([1.]))
    assert model.blocks[0].subblock1.weight is not None
    assert model.blocks[1].subblock1.weight is not None
    assert model.blocks[2].subblock1.weight is None
    assert model.unembed.weight is None

    _, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])
126 changes: 123 additions & 3 deletions TransformerLens/transformer_lens/hook_points.py
@@ -26,7 +26,7 @@
import torch.nn as nn
import torch.utils.hooks as hooks

-from transformer_lens.utils import Slice, SliceInput
+from transformer_lens.utils import Slice, SliceInput, set_nested_attr


@dataclass
@@ -51,8 +51,7 @@ class LensHandle:
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions."""

-    def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]:
-        ...
+    def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]: ...


HookFunction = _HookFunctionProtocol # Callable[..., _HookFunctionProtocol]
@@ -776,5 +775,126 @@ def run_with_ref_cache(

        return model_out, cache_dict

    def offload_params_after(self, last_hook: str, *model_args, **model_kwargs):
        """Set parameters that are not used after a certain hook to None.

        This does not guarantee that all parameters are offloaded, but it should offload most of
        them. Specifically, the direct parameters of the ancestor modules of the last hook are not
        offloaded, since there is no way to know whether they are used before or after the last
        hook.

        Args:
            last_hook (str): The name of the last hook.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model.
        """
        pass_module_list: List[nn.Module] = []
        fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
        hook_handles: List[hooks.RemovableHandle] = []

        def pass_hook(module: nn.Module, module_input: Any, module_output: Any):
            # Record every module whose forward pass completes before the last hook fires.
            pass_module_list.append(module)

        def convert_hook(tensor: torch.Tensor, hook: HookPoint):
            # Keep the direct parameters of modules that already ran and of the hook's
            # ancestor modules; every other parameter can be offloaded.
            pass_param_set = set()
            hook_ancestors = [module for module in self.modules() if module == self or hook.name.startswith(module.name)]
            for module in pass_module_list + hook_ancestors:
                for param_name, parameters in module.named_parameters():
                    if "." not in param_name:
                        name = f"{module.name}.{param_name}" if module != self else param_name
                        pass_param_set.add(name)

            fake_param_set = set([name for name, _ in self.named_parameters()]).difference(pass_param_set)

            for name in fake_param_set:
                set_nested_attr(self, name, None)
            raise StopIteration

        for _, module in self.named_modules():
            hook_handles.append(module.register_forward_hook(pass_hook))

        with fake_mode:
            with self.hooks(fwd_hooks=[(last_hook, convert_hook)]):
                try:
                    self(*model_args, **model_kwargs)
                except StopIteration:
                    pass

        for handle in hook_handles:
            handle.remove()
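A minimal usage sketch of offload_params_after, mirroring the acceptance test above; the model name, hook point, and dummy input are illustrative, and a CUDA device is assumed:

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-2l", device="cuda")

with torch.no_grad():
    # Trace a dummy forward pass and set every parameter that is only needed
    # after blocks.0.hook_resid_post to None, freeing its memory.
    model.offload_params_after("blocks.0.hook_resid_post", torch.tensor([[0]], device="cuda"))

# Parameters of later blocks (and the unembedding) are now None; the acceptance
# test above expects allocated memory to drop below roughly 55% of the original.
print(torch.cuda.memory_allocated(0))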

    def run_with_cache_until(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        until: Optional[str] = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """Runs the model, caching activations, and stops the forward pass at a given hook.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts
                None, str, list of str, or a function that takes a string and returns a bool.
                Defaults to None, which means cache everything.
            until (str, optional): The name of the hook to stop the forward pass at. Defaults to
                None, which means stop at the last hook selected for caching.
            device (str or torch.Device, optional): The device to cache activations on. Defaults
                to the model device. WARNING: Setting a different device than the one used by the
                model leads to significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching.
                Only makes sense with batch_size=1 inputs. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at
                the end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are
                reset. Defaults to False.
            pos_slice (Slice or SliceInput, optional): The slice to apply to the cached
                activations. Defaults to None, which means no slicing.
            **model_kwargs: Keyword arguments for the model.

        Returns:
            tuple: The model output (or, if the run stops early, the activation at `until`) and a
            dictionary of cached activations.
        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, _ = self.get_caching_hooks(
            names_filter,
            False,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        if until is None:
            until = fwd[-1][0]

        class ModuleStop(Exception):
            """Raised by the stop hook to abort the forward pass at `until`."""

            def __init__(self, tensor: torch.Tensor):
                self.tensor = tensor

        def stop_hook(tensor: torch.Tensor, hook: HookPoint):
            if hook.name == until:
                raise ModuleStop(tensor)

        with self.hooks(
            fwd_hooks=fwd + [(until, stop_hook)],
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            try:
                model_out = self(*model_args, **model_kwargs)
            except ModuleStop as e:
                model_out = e.tensor

        return model_out, cache_dict


# %%
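A minimal usage sketch of run_with_cache_until in the same spirit as the tests above (the model name and hook point are illustrative); note that when the run stops early, the first return value is the activation at the stopping hook rather than the logits:

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-2l", device="cuda")

with torch.no_grad():
    # Runs the model only up to the requested hook, caching the filtered activations.
    out, cache = model.run_with_cache_until(
        "Hello, world!",
        names_filter=["blocks.0.hook_resid_post"],
        until="blocks.0.hook_resid_post",
    )

resid = cache["blocks.0.hook_resid_post"]  # cached activation at the hook
# `out` equals the tensor seen by the stopping hook, as the integration test asserts.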
2 changes: 1 addition & 1 deletion examples/train.py
@@ -26,7 +26,7 @@
store_batch_size = 32, # The batch size for loading the corpus.

# ActivationStoreConfig
-hook_points = ["blocks.3.hook_mlp_out"], # The hook point to extract the activations, i.e. the layer output of which is used for training/evaluating the dictionary.
+hook_points = ["blocks.3.hook_mlp_out"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`)
n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary.
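Because activations are now collected with run_with_cache_until, the forward pass stops at the last entry of hook_points, so the list must follow execution order. A hypothetical two-hook configuration (layer indices are illustrative):

hook_points = [
    "blocks.3.hook_mlp_out",  # layer 3 output, reached first
    "blocks.7.hook_mlp_out",  # layer 7 output, reached last, so it must be the final entry
]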

6 changes: 3 additions & 3 deletions server/app.py
@@ -58,7 +58,7 @@ def get_model(dictionary_name: str) -> HookedTransformer:
else cfg.model_from_pretrained_path
),
trust_remote_code=True,
-use_fast=True,
+use_fast=False,
add_bos_token=True,
)
model = HookedTransformer.from_pretrained(
@@ -240,7 +240,7 @@ def feature_activation_custom_input(
model = get_model(dictionary_name)
with torch.no_grad():
input = model.to_tokens(input_text, prepend_bos=False)
-_, cache = model.run_with_cache(input, names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out])
+_, cache = model.run_with_cache_until(input, names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out], until=sae.cfg.hook_point_out)

feature_acts = sae.encode(cache[sae.cfg.hook_point_in][0], label=cache[sae.cfg.hook_point_out][0])
sample = {
@@ -269,7 +269,7 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):

with torch.no_grad():
input = model.to_tokens(input_text, prepend_bos=False)
-_, cache = model.run_with_cache(input, names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out])
+_, cache = model.run_with_cache_until(input, names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out], until=sae.cfg.hook_point_out)

feature_acts = sae.encode(cache[sae.cfg.hook_point_in][0], label=cache[sae.cfg.hook_point_out][0])
sample = {
2 changes: 1 addition & 1 deletion src/lm_saes/activation/activation_dataset.py
@@ -44,7 +44,7 @@ def make_activation_dataset(

while n_tokens_in_chunk < max_tokens_per_chunk:
tokens = token_source.next(cfg.store_batch_size)
-_, cache = model.run_with_cache(tokens, names_filter=cfg.hook_points)
+_, cache = model.run_with_cache_until(tokens, names_filter=cfg.hook_points, until=cfg.hook_points[-1])
for hook_point in cfg.hook_points:
act = cache[hook_point]
act_dict[hook_point] = torch.cat([act_dict[hook_point], act], dim=0)
2 changes: 1 addition & 1 deletion src/lm_saes/activation/activation_source.py
@@ -45,7 +45,7 @@ def next(self) -> Dict[str, torch.Tensor] | None:
if tokens is None:
return None
with torch.no_grad():
-_, cache = self.model.run_with_cache(tokens, names_filter=self.cfg.hook_points)
+_, cache = self.model.run_with_cache_until(tokens, names_filter=self.cfg.hook_points, until=self.cfg.hook_points[-1])

filter_mask = torch.logical_and(tokens.ne(self.model.tokenizer.eos_token_id), tokens.ne(self.model.tokenizer.pad_token_id))
filter_mask = torch.logical_and(filter_mask, tokens.ne(self.model.tokenizer.bos_token_id))
2 changes: 1 addition & 1 deletion src/lm_saes/analysis/auto_interp.py
@@ -162,7 +162,7 @@ def check_description(
cost = _calculate_cost(input_tokens, output_tokens)
input_index, input_text = index, response
input_token = model.to_tokens(input_text)
-_, cache = model.run_with_cache(input_token, names_filter=[cfg.hook_point_in, cfg.hook_point_out])
+_, cache = model.run_with_cache_until(input_token, names_filter=[cfg.hook_point_in, cfg.hook_point_out], until=cfg.hook_point_out)
activation_in, activation_out = cache[cfg.hook_point_in][0], cache[cfg.hook_point_out][0]
feature_acts = sae.encode(activation_in, label=activation_out)
max_value, max_pos = torch.max(feature_acts, dim=0)
11 changes: 2 additions & 9 deletions src/lm_saes/analysis/sample_feature_activations.py
@@ -45,9 +45,6 @@ def sample_feature_activations(
start_index = sae_chunk_id * d_sae
end_index = (sae_chunk_id + 1) * d_sae

-hook_point_out = cfg.hook_point_out
-stop_at_layer = int(hook_point_out.split(".")[1]) + 1 # fuck this hard code

sample_result = {k: {
"elt": torch.empty((0, d_sae), dtype=cfg.dtype, device=cfg.device),
"feature_acts": torch.empty((0, d_sae, cfg.context_size), dtype=cfg.dtype, device=cfg.device),
@@ -63,14 +60,10 @@
if batch is None:
raise ValueError("Not enough tokens to sample")

-_, cache = model.run_with_cache(batch, names_filter=[cfg.hook_point_in, cfg.hook_point_out], stop_at_layer=stop_at_layer)
+_, cache = model.run_with_cache_until(batch, names_filter=[cfg.hook_point_in, cfg.hook_point_out], until=cfg.hook_point_out)
activation_in, activation_out = cache[cfg.hook_point_in], cache[cfg.hook_point_out]

-(
-    _,
-    (_, aux_data),
-) = sae.forward(activation_in, label=activation_out)
-feature_acts = aux_data["feature_acts"][..., start_index: end_index]
+feature_acts = sae.encode(activation_in, label=activation_out)[..., start_index: end_index]

act_times += feature_acts.gt(0.0).sum(dim=[0, 1])

(Diffs for the remaining 2 changed files are not shown.)
