-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Accelerate Inference in TransformerLens #26
Changes from 10 commits
0d1220b
a916e86
38af520
e6cec69
20a432d
b85c98e
a12cc22
5447661
90a3361
310193e
c8c86bd
24ac042
b78b52f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,17 @@ | |
import bitsandbytes as bnb | ||
from bitsandbytes.nn.modules import Params4bit | ||
|
||
# From transformers/models/llama/modeling_llama.py | ||
def _get_unpad_data(attention_mask): | ||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) | ||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() | ||
max_seqlen_in_batch = seqlens_in_batch.max().item() | ||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) | ||
return ( | ||
indices, | ||
cu_seqlens, | ||
max_seqlen_in_batch, | ||
) | ||
|
||
class AbstractAttention(ABC, nn.Module): | ||
alibi: Union[torch.Tensor, None] | ||
|
@@ -96,13 +107,26 @@ def __init__( | |
if self.cfg.scale_attn_by_inverse_layer_idx: | ||
assert self.layer_id is not None # keep mypy happy | ||
self.attn_scale *= self.layer_id + 1 | ||
|
||
self.hook_k = HookPoint() # [batch, pos, head_index, d_head] | ||
self.hook_q = HookPoint() # [batch, pos, head_index, d_head] | ||
self.hook_v = HookPoint() # [batch, pos, head_index, d_head] | ||
self.hook_z = HookPoint() # [batch, pos, head_index, d_head] | ||
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] | ||
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] | ||
|
||
# Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked. | ||
if self.cfg.use_flash_attn: | ||
from flash_attn import flash_attn_func, flash_attn_varlen_func | ||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa | ||
self.flash_attn_func = flash_attn_func | ||
self.flash_attn_varlen_func = flash_attn_varlen_func | ||
self.fa_index_first_axis = index_first_axis | ||
self.fa_pad_input = pad_input | ||
self.fa_unpad_input = unpad_input | ||
else: | ||
self.hook_z = HookPoint() # [batch, pos, head_index, d_head] | ||
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] | ||
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] | ||
|
||
|
||
self.hook_result = HookPoint() # [batch, pos, head_index, d_model] | ||
|
||
# See HookedTransformerConfig for more details. | ||
|
@@ -195,45 +219,72 @@ def forward( | |
self.apply_rotary(k, 0, attention_mask) | ||
) # keys are cached so no offset | ||
|
||
if self.cfg.dtype not in [torch.float32, torch.float64]: | ||
if self.cfg.dtype not in [torch.float32, torch.float64] and self.cfg.dtype != torch.bfloat16: | ||
# If using 16 bits, increase the precision to avoid numerical instabilities | ||
q = q.to(torch.float32) | ||
k = k.to(torch.float32) | ||
if self.cfg.use_flash_attn: | ||
# use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case. | ||
# Contains at least one padding token in the sequence | ||
causal = True if self.cfg.attention_dir == "causal" else False | ||
if attention_mask is not None: | ||
batch_size, query_length, _ = q.shape | ||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( | ||
q, k, v, attention_mask, q.shape[1] | ||
) | ||
|
||
attn_scores = self.calculate_attention_scores( | ||
q, k | ||
) # [batch, head_index, query_pos, key_pos] | ||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens | ||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens | ||
|
||
attn_output_unpad = self.flash_attn_varlen_func( | ||
query_states, | ||
key_states, | ||
value_states, | ||
cu_seqlens_q=cu_seqlens_q, | ||
cu_seqlens_k=cu_seqlens_k, | ||
max_seqlen_q=max_seqlen_in_batch_q, | ||
max_seqlen_k=max_seqlen_in_batch_k, | ||
causal=causal, | ||
) | ||
|
||
if self.cfg.positional_embedding_type == "alibi": | ||
query_ctx = attn_scores.size(-2) | ||
# The key context length is the number of positions in the past - this includes all positions in the cache | ||
key_ctx = attn_scores.size(-1) | ||
z = self.fa_pad_input(attn_output_unpad, indices_q, batch_size, query_length) | ||
else: | ||
z = self.flash_attn_func(q, k, v, causal=causal) | ||
else: | ||
attn_scores = self.calculate_attention_scores( | ||
q, k | ||
) # [batch, head_index, query_pos, key_pos] | ||
|
||
# only recompute when necessary to increase efficiency. | ||
if self.alibi is None or key_ctx > self.alibi.size(-1): | ||
self.alibi = AbstractAttention.create_alibi_bias( | ||
self.cfg.n_heads, key_ctx, self.cfg.device | ||
) | ||
if self.cfg.positional_embedding_type == "alibi": | ||
query_ctx = attn_scores.size(-2) | ||
# The key context length is the number of positions in the past - this includes all positions in the cache | ||
key_ctx = attn_scores.size(-1) | ||
|
||
attn_scores += self.alibi[ | ||
:, :query_ctx, :key_ctx | ||
] # [batch, head_index, query_pos, key_pos] | ||
# only recompute when necessary to increase efficiency. | ||
if self.alibi is None or key_ctx > self.alibi.size(-1): | ||
self.alibi = AbstractAttention.create_alibi_bias( | ||
self.cfg.n_heads, key_ctx, self.cfg.device | ||
) | ||
|
||
if self.cfg.attention_dir == "causal": | ||
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. | ||
attn_scores = self.apply_causal_mask( | ||
attn_scores, kv_cache_pos_offset, attention_mask | ||
) # [batch, head_index, query_pos, key_pos] | ||
if additive_attention_mask is not None: | ||
attn_scores += additive_attention_mask | ||
|
||
attn_scores = self.hook_attn_scores(attn_scores) | ||
pattern = F.softmax(attn_scores, dim=-1) | ||
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) | ||
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] | ||
pattern = pattern.to(self.cfg.dtype) | ||
pattern = pattern.to(v.device) | ||
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] | ||
attn_scores += self.alibi[ | ||
:, :query_ctx, :key_ctx | ||
] # [batch, head_index, query_pos, key_pos] | ||
|
||
if self.cfg.attention_dir == "causal": | ||
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. | ||
attn_scores = self.apply_causal_mask( | ||
attn_scores, kv_cache_pos_offset, attention_mask | ||
) # [batch, head_index, query_pos, key_pos] | ||
if additive_attention_mask is not None: | ||
attn_scores += additive_attention_mask | ||
|
||
attn_scores = self.hook_attn_scores(attn_scores) | ||
pattern = F.softmax(attn_scores, dim=-1) | ||
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) | ||
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] | ||
pattern = pattern.to(self.cfg.dtype) | ||
pattern = pattern.to(v.device) | ||
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] | ||
if not self.cfg.use_attn_result: | ||
if self.cfg.load_in_4bit: | ||
# call bitsandbytes method to dequantize and multiply | ||
|
@@ -656,3 +707,41 @@ def create_alibi_bias( | |
alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) | ||
|
||
return alibi_bias | ||
|
||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add necessary type hints and comments to this function. |
||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) | ||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape | ||
|
||
key_layer = self.fa_index_first_axis( | ||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k | ||
) | ||
value_layer = self.fa_index_first_axis( | ||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k | ||
) | ||
if query_length == kv_seq_len: | ||
query_layer = self.fa_index_first_axis( | ||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k | ||
) | ||
cu_seqlens_q = cu_seqlens_k | ||
max_seqlen_in_batch_q = max_seqlen_in_batch_k | ||
indices_q = indices_k | ||
elif query_length == 1: | ||
max_seqlen_in_batch_q = 1 | ||
cu_seqlens_q = torch.arange( | ||
batch_size + 1, dtype=torch.int32, device=query_layer.device | ||
) # There is a memcpy here, that is very bad. | ||
indices_q = cu_seqlens_q[:-1] | ||
query_layer = query_layer.squeeze(1) | ||
else: | ||
# The -q_len: slice assumes left padding. | ||
attention_mask = attention_mask[:, -query_length:] | ||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = self.fa_unpad_input(query_layer, attention_mask) | ||
|
||
return ( | ||
query_layer, | ||
key_layer, | ||
value_layer, | ||
indices_q, | ||
(cu_seqlens_q, cu_seqlens_k), | ||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k), | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,7 @@ use_ghost_grads = true | |
|
||
[lm] | ||
model_name = "gpt2" | ||
use_flash_attn = false | ||
d_model = 768 | ||
|
||
[dataset] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,4 +59,9 @@ check_untyped_defs=true | |
exclude=[".venv/", "examples", "TransformerLens", "tests", "exp"] | ||
ignore_missing_imports=true | ||
allow_redefinition=true | ||
implicit_optional=true | ||
implicit_optional=true | ||
|
||
[build-system] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please explain why these requirements are necessary. |
||
requires = ["pdm-pep517"] | ||
build-backend = "pdm.pep517.api" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this file testing if configs can be successfully created? If true, it seems better to try creating several hard-coded configs instead of depending on command line arguments for the sake of automated testing. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
import pytest | ||
|
||
from lm_saes.config import LanguageModelConfig | ||
from lm_saes.runner import language_model_sae_runner | ||
|
||
def pytest_addoption(parser): | ||
parser.addoption("--layer", nargs="*", type=int, required=False, help='Layer number') | ||
parser.addoption("--batch_size", type=int, required=False, default=4096, help='Batchsize, default 4096') | ||
parser.addoption("--lr", type=float, required=False, default=8e-5, help='Learning rate, default 8e-5') | ||
parser.addoption("--expdir", type=str, required=False, default="path/to/results", help='Export directory, default path') | ||
parser.addoption("--useddp", type=bool, required=False, default=False, help='If using distributed method, default False') | ||
parser.addoption('--attn_type', type=str, required=False, choices=['flash', 'normal'], default="flash", help='Use or not use log of wandb, default True') | ||
parser.addoption('--dtype', type=str, required=False, choices=['fp32', 'bfp16'], default="fp32", help='Dtype, default fp32') | ||
parser.addoption('--model_name', type=str, required=False, default="meta-llama/Meta-Llama-3-8B", help='Supported model name of TransformerLens, default gpt2') | ||
parser.addoption('--d_model', type=int, required=False, default=4096, help='Dimension of model hidden states, default 4096') | ||
parser.addoption('--model_path', type=str, required=False, default="path/to/model", help='Hugging-face model path used to load.') | ||
|
||
@pytest.fixture | ||
def args(request): | ||
return {"layer":request.config.getoption("--layer"), | ||
"batch_size":request.config.getoption("--batch_size"), | ||
"lr":request.config.getoption("--lr"), | ||
"expdir":request.config.getoption("--expdir"), | ||
"useddp":request.config.getoption("--useddp"), | ||
"attn_type":request.config.getoption("--attn_type"), | ||
"dtype":request.config.getoption("--dtype"), | ||
"model_name":request.config.getoption("--model_name"), | ||
"model_path":request.config.getoption("--model_path"), | ||
"d_model":request.config.getoption("--d_model"), | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please explain why excluding
torch.bfloat16
. Besides,torch.bfloat16
could be put inside the exclusion lists.