Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate Inference in TransformerLens #26

Merged
merged 13 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions TransformerLens/transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ def from_pretrained(
cls,
model_name: str,
fold_ln: bool = True,
use_flash_attn: bool = False,
center_writing_weights: bool = True,
center_unembed: bool = True,
refactor_factored_attn_matrices: bool = False,
Expand Down Expand Up @@ -1240,6 +1241,7 @@ def from_pretrained(
checkpoint_index=checkpoint_index,
checkpoint_value=checkpoint_value,
fold_ln=fold_ln,
use_flash_attn=use_flash_attn,
device=device,
n_devices=n_devices,
default_prepend_bos=default_prepend_bos,
Expand Down
3 changes: 3 additions & 0 deletions TransformerLens/transformer_lens/HookedTransformerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class HookedTransformerConfig:
custom config, if loading from pretrained then this is not needed.
use_local_attn (bool): whether to use local attention - ie each
destination token can only attend to source tokens a certain distance back.
use_flash_attn (bool): whether to use FlashAttention-2. Please refer to
https://github.com/Dao-AILab/flash-attention.
window_size (int, *optional*): the size of the window for local
attention
attn_types (List[str], *optional*): the types of attention to use for
Expand Down Expand Up @@ -177,6 +179,7 @@ class HookedTransformerConfig:
use_hook_mlp_in: bool = False
use_attn_in: bool = False
use_local_attn: bool = False
use_flash_attn: bool = False
original_architecture: Optional[str] = None
from_checkpoint: bool = False
checkpoint_index: Optional[int] = None
Expand Down
159 changes: 124 additions & 35 deletions TransformerLens/transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit

# From transformers/models/llama/modeling_llama.py
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)

class AbstractAttention(ABC, nn.Module):
alibi: Union[torch.Tensor, None]
Expand Down Expand Up @@ -96,13 +107,26 @@ def __init__(
if self.cfg.scale_attn_by_inverse_layer_idx:
assert self.layer_id is not None # keep mypy happy
self.attn_scale *= self.layer_id + 1

self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]

# Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked.
if self.cfg.use_flash_attn:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
self.flash_attn_func = flash_attn_func
self.flash_attn_varlen_func = flash_attn_varlen_func
self.fa_index_first_axis = index_first_axis
self.fa_pad_input = pad_input
self.fa_unpad_input = unpad_input
else:
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]


self.hook_result = HookPoint() # [batch, pos, head_index, d_model]

# See HookedTransformerConfig for more details.
Expand Down Expand Up @@ -195,45 +219,72 @@ def forward(
self.apply_rotary(k, 0, attention_mask)
) # keys are cached so no offset

if self.cfg.dtype not in [torch.float32, torch.float64]:
if self.cfg.dtype not in [torch.float32, torch.float64] and self.cfg.dtype != torch.bfloat16:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain why excluding torch.bfloat16. Besides, torch.bfloat16 could be put inside the exclusion lists.

# If using 16 bits, increase the precision to avoid numerical instabilities
q = q.to(torch.float32)
k = k.to(torch.float32)
if self.cfg.use_flash_attn:
# use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case.
# Contains at least one padding token in the sequence
causal = True if self.cfg.attention_dir == "causal" else False
if attention_mask is not None:
batch_size, query_length, _ = q.shape
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
q, k, v, attention_mask, q.shape[1]
)

attn_scores = self.calculate_attention_scores(
q, k
) # [batch, head_index, query_pos, key_pos]
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = self.flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
causal=causal,
)

if self.cfg.positional_embedding_type == "alibi":
query_ctx = attn_scores.size(-2)
# The key context length is the number of positions in the past - this includes all positions in the cache
key_ctx = attn_scores.size(-1)
z = self.fa_pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
z = self.flash_attn_func(q, k, v, causal=causal)
else:
attn_scores = self.calculate_attention_scores(
q, k
) # [batch, head_index, query_pos, key_pos]

# only recompute when necessary to increase efficiency.
if self.alibi is None or key_ctx > self.alibi.size(-1):
self.alibi = AbstractAttention.create_alibi_bias(
self.cfg.n_heads, key_ctx, self.cfg.device
)
if self.cfg.positional_embedding_type == "alibi":
query_ctx = attn_scores.size(-2)
# The key context length is the number of positions in the past - this includes all positions in the cache
key_ctx = attn_scores.size(-1)

attn_scores += self.alibi[
:, :query_ctx, :key_ctx
] # [batch, head_index, query_pos, key_pos]
# only recompute when necessary to increase efficiency.
if self.alibi is None or key_ctx > self.alibi.size(-1):
self.alibi = AbstractAttention.create_alibi_bias(
self.cfg.n_heads, key_ctx, self.cfg.device
)

if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
attn_scores, kv_cache_pos_offset, attention_mask
) # [batch, head_index, query_pos, key_pos]
if additive_attention_mask is not None:
attn_scores += additive_attention_mask

attn_scores = self.hook_attn_scores(attn_scores)
pattern = F.softmax(attn_scores, dim=-1)
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
pattern = pattern.to(self.cfg.dtype)
pattern = pattern.to(v.device)
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
attn_scores += self.alibi[
:, :query_ctx, :key_ctx
] # [batch, head_index, query_pos, key_pos]

if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
attn_scores, kv_cache_pos_offset, attention_mask
) # [batch, head_index, query_pos, key_pos]
if additive_attention_mask is not None:
attn_scores += additive_attention_mask

attn_scores = self.hook_attn_scores(attn_scores)
pattern = F.softmax(attn_scores, dim=-1)
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
pattern = pattern.to(self.cfg.dtype)
pattern = pattern.to(v.device)
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
if not self.cfg.use_attn_result:
if self.cfg.load_in_4bit:
# call bitsandbytes method to dequantize and multiply
Expand Down Expand Up @@ -656,3 +707,41 @@ def create_alibi_bias(
alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)

return alibi_bias

def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add necessary type hints and comments to this function.

indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = self.fa_index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = self.fa_index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = self.fa_index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = self.fa_unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
5 changes: 5 additions & 0 deletions TransformerLens/transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,7 @@ def get_pretrained_model_config(
checkpoint_index: Optional[int] = None,
checkpoint_value: Optional[int] = None,
fold_ln: bool = False,
use_flash_attn: bool = False,
device: Optional[Union[str, torch.device]] = None,
n_devices: int = 1,
default_prepend_bos: bool = True,
Expand Down Expand Up @@ -1251,6 +1252,8 @@ def get_pretrained_model_config(
fold_ln (bool, optional): Whether to fold the layer norm into the
subsequent linear layers (see HookedTransformer.fold_layer_norm for
details). Defaults to False.
use_flash_attn (bool): whether to use FlashAttention-2. Please refer to
https://github.com/Dao-AILab/flash-attention. Defaults to False.
device (str, optional): The device to load the model onto. By
default will load to CUDA if available, else CPU.
n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
Expand Down Expand Up @@ -1310,6 +1313,8 @@ def get_pretrained_model_config(
cfg_dict["normalization_type"] = "RMSPre"
else:
logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
if use_flash_attn:
cfg_dict["use_flash_attn"] = True

if checkpoint_index is not None or checkpoint_value is not None:
checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
Expand Down
1 change: 1 addition & 0 deletions examples/configuration/analyze.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ exp_result_dir = "results"
[lm]
model_name = "gpt2"
d_model = 768
use_flash_attn = false

[dataset]
dataset_path = "openwebtext"
Expand Down
1 change: 1 addition & 0 deletions examples/configuration/prune.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ decoder_norm_threshold = 0.99
[lm]
model_name = "gpt2"
d_model = 768
use_flash_attn = false

[dataset]
dataset_path = "openwebtext"
Expand Down
1 change: 1 addition & 0 deletions examples/configuration/train.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use_ghost_grads = true

[lm]
model_name = "gpt2"
use_flash_attn = false
d_model = 768

[dataset]
Expand Down
1 change: 1 addition & 0 deletions examples/programmatic/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# LanguageModelConfig
model_name = "gpt2", # The model name or path for the pre-trained model.
d_model = 768, # The hidden size of the model.
use_flash_attn = False, # Whether to use FlashAttentionV2

# TextDatasetConfig
dataset_path = 'Skylion007/OpenWebText', # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,9 @@ check_untyped_defs=true
exclude=[".venv/", "examples", "TransformerLens", "tests", "exp"]
ignore_missing_imports=true
allow_redefinition=true
implicit_optional=true
implicit_optional=true

[build-system]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain why these requirements are necessary.

requires = ["pdm-pep517"]
build-backend = "pdm.pep517.api"

1 change: 1 addition & 0 deletions src/lm_saes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __post_init__(self):
class LanguageModelConfig(BaseModelConfig):
model_name: str = "gpt2"
model_from_pretrained_path: Optional[str] = None
use_flash_attn: bool = False
cache_dir: Optional[str] = None
d_model: int = 768
local_files_only: bool = False
Expand Down
6 changes: 6 additions & 0 deletions src/lm_saes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):

model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
use_flash_attn=cfg.lm.use_flash_attn,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
Expand Down Expand Up @@ -143,6 +144,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
)
model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
use_flash_attn=cfg.lm.use_flash_attn,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
Expand Down Expand Up @@ -211,6 +213,7 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):
)
model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
use_flash_attn=cfg.lm.use_flash_attn,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
Expand Down Expand Up @@ -274,6 +277,7 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
)
model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
use_flash_attn=cfg.lm.use_flash_attn,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
Expand Down Expand Up @@ -309,6 +313,7 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
)
model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
use_flash_attn=cfg.lm.use_flash_attn,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
Expand Down Expand Up @@ -377,6 +382,7 @@ def features_to_logits_runner(cfg: FeaturesDecoderConfig):
)
model = HookedTransformer.from_pretrained(
cfg.lm.model_name,
use_flash_attn=cfg.lm.use_flash_attn,
device=cfg.lm.device,
cache_dir=cfg.lm.cache_dir,
hf_model=hf_model,
Expand Down
31 changes: 31 additions & 0 deletions tests/conftest.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this file testing if configs can be successfully created? If true, it seems better to try creating several hard-coded configs instead of depending on command line arguments for the sake of automated testing.

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
import pytest

from lm_saes.config import LanguageModelConfig
from lm_saes.runner import language_model_sae_runner

def pytest_addoption(parser):
parser.addoption("--layer", nargs="*", type=int, required=False, help='Layer number')
parser.addoption("--batch_size", type=int, required=False, default=4096, help='Batchsize, default 4096')
parser.addoption("--lr", type=float, required=False, default=8e-5, help='Learning rate, default 8e-5')
parser.addoption("--expdir", type=str, required=False, default="path/to/results", help='Export directory, default path')
parser.addoption("--useddp", type=bool, required=False, default=False, help='If using distributed method, default False')
parser.addoption('--attn_type', type=str, required=False, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True')
parser.addoption('--dtype', type=str, required=False, choices=['fp32', 'bfp16'], default="fp32", help='Dtype, default fp32')
parser.addoption('--model_name', type=str, required=False, default="meta-llama/Meta-Llama-3-8B", help='Supported model name of TransformerLens, default gpt2')
parser.addoption('--d_model', type=int, required=False, default=4096, help='Dimension of model hidden states, default 4096')
parser.addoption('--model_path', type=str, required=False, default="path/to/model", help='Hugging-face model path used to load.')

@pytest.fixture
def args(request):
return {"layer":request.config.getoption("--layer"),
"batch_size":request.config.getoption("--batch_size"),
"lr":request.config.getoption("--lr"),
"expdir":request.config.getoption("--expdir"),
"useddp":request.config.getoption("--useddp"),
"attn_type":request.config.getoption("--attn_type"),
"dtype":request.config.getoption("--dtype"),
"model_name":request.config.getoption("--model_name"),
"model_path":request.config.getoption("--model_path"),
"d_model":request.config.getoption("--d_model"),
}
Loading