Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flash attention support. #20152

Merged
merged 13 commits into from
Oct 8, 2024
56 changes: 56 additions & 0 deletions keras/src/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,3 +853,59 @@ def psnr(x1, x2, max_val):
mse = torch.mean((x1 - x2) ** 2)
psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse)
return psnr


def make_causal_mask(size, dtype, device):
mask = torch.tril(torch.ones((size, size), dtype=dtype, device=device))
mask = mask.masked_fill(mask.logical_not(), -torch.inf)
mask = mask.view((1, 1, size, size))
return mask


def flash_attention(
query, key, value, attn_mask=None, dropout=0.0, is_causal=False, scale=None
):
if query.ndim < 4:
raise ValueError(
"Expected `query` to have 4 dims. " f"Received: {query.ndim}."
)

if key.ndim < 4:
raise ValueError(
"Expected `key` to have 4 dims. " f"Received: {key.ndim}."
)

if value.ndim < 4:
raise ValueError(
"Expected `value` to have 4 dims. " f"Received: {value.ndim}."
)

flash_attn_backend = [torch.nn.attention.SDPBackend.FLASH_ATTENTION]
mask = None
if is_causal:
# We manually create the causal mask here instead of setting
# `is_causal=True` in the PyTorch function
# because it will not accept attention mask if we did that.
mask = make_causal_mask(query.shape[-2], query.dtype, query.device)

if attn_mask is not None and mask is not None:
attn_mask = convert_to_tensor(attn_mask)
if attn_mask.ndim == 2:
attn_mask = attn_mask.view(
(attn_mask.shape[0], 1, 1, attn_mask.shape[1])
)

attn_mask = attn_mask.masked_fill(attn_mask.logical_not(), -torch.inf)
mask += attn_mask

with torch.nn.attention.sdpa_kernel(backends=flash_attn_backend):
output = tnn.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=mask,
dropout_p=dropout,
is_causal=False,
scale=scale,
)
return output