Flash attention support. #20152

Merged
merged 13 commits on Oct 8, 2024
18 changes: 17 additions & 1 deletion keras/src/backend/jax/nn.py
@@ -986,7 +986,14 @@ def _dot_product_attention_core(


def dot_product_attention(
query, key, value, bias=None, mask=None, scale=None, is_causal=False
query,
key,
value,
bias=None,
mask=None,
scale=None,
is_causal=False,
flash_attention=False,
):
query = convert_to_tensor(query)
key = convert_to_tensor(key)
@@ -1000,6 +1007,7 @@ def dot_product_attention(

# `dot_product_attention` is only available in jax>=0.4.31
if hasattr(jax.nn, "dot_product_attention"):
implementation = "cudnn" if flash_attention else "xla"
return jax.nn.dot_product_attention(
query,
key,
@@ -1008,6 +1016,14 @@ def dot_product_attention(
mask=mask,
scale=scale,
is_causal=is_causal,
implementation=implementation,
)

if flash_attention:
raise ValueError(
"Flash attention is not supported in your "
"current JAX version. Please update it "
"using `pip install -U jax jaxlib`."
)

# Ref: jax.nn.dot_product_attention
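For reference, a minimal sketch of the kernel selection the JAX branch above performs: `flash_attention=True` maps to `implementation="cudnn"` (the flash kernel) and the default path maps to `implementation="xla"`. It assumes jax>=0.4.31, which provides `jax.nn.dot_product_attention`, plus a cuDNN-capable GPU; the inputs are hypothetical and shaped (batch, seq_len, num_heads, head_dim).

import jax
import jax.numpy as jnp

# Hypothetical inputs shaped (batch, seq_len, num_heads, head_dim).
query = jnp.ones((2, 8, 4, 64), dtype=jnp.bfloat16)
key = jnp.ones((2, 8, 4, 64), dtype=jnp.bfloat16)
value = jnp.ones((2, 8, 4, 64), dtype=jnp.bfloat16)

# "cudnn" requests the fused flash attention kernel; "xla" is the default path.
out = jax.nn.dot_product_attention(
    query, key, value, is_causal=True, implementation="cudnn"
)
print(out.shape)  # (2, 8, 4, 64)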
11 changes: 10 additions & 1 deletion keras/src/backend/numpy/nn.py
@@ -1033,8 +1033,17 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):


def dot_product_attention(
query, key, value, bias=None, mask=None, scale=None, is_causal=False
query,
key,
value,
bias=None,
mask=None,
scale=None,
is_causal=False,
flash_attention=False,
):
if flash_attention:
raise ValueError("Flash attention is not implemented in NumPy.")
# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828
# Does not support the `query_seq_lengths` and `key_value_seq_lengths` args
14 changes: 13 additions & 1 deletion keras/src/backend/tensorflow/nn.py
@@ -964,8 +964,20 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):


def dot_product_attention(
query, key, value, bias=None, mask=None, scale=None, is_causal=False
query,
key,
value,
bias=None,
mask=None,
scale=None,
is_causal=False,
flash_attention=False,
):
if flash_attention:
raise ValueError(
"Flash attention is not yet supported in the TensorFlow backend."
)

# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828
# Does not support the `query_seq_lengths` and `key_value_seq_lengths` args
56 changes: 52 additions & 4 deletions keras/src/backend/torch/nn.py
@@ -864,8 +864,28 @@ def _get_large_negative(dtype):
return convert_to_tensor(val * -0.7, dtype=dtype)


def is_flash_attention_enabled(query, key, value, mask=None, is_causal=False):
params = torch.backends.cuda.SDPAParams(
query,
key,
value,
mask,
0.0,
is_causal,
)
is_enabled = torch.backends.cuda.can_use_flash_attention(params, False)
return is_enabled


def dot_product_attention(
query, key, value, bias=None, mask=None, scale=None, is_causal=False
query,
key,
value,
bias=None,
mask=None,
scale=None,
is_causal=False,
flash_attention=False,
):
if bias is not None:
raise ValueError(
@@ -891,7 +911,35 @@ def dot_product_attention(
query = torch.transpose(query, axis0, axis1)
key = torch.transpose(key, axis0, axis1)
value = torch.transpose(value, axis0, axis1)
attention_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
)

if flash_attention:
is_enabled = is_flash_attention_enabled(
query=query,
key=key,
value=value,
mask=mask,
is_causal=is_causal,
)
if not is_enabled:
raise ValueError(
"Flash attention is not enabled in the `torch` backend. "
"The inputs must be of dtype float16/bfloat16 "
"and your GPU must support the flash attention implementation."
)

with torch.nn.attention.sdpa_kernel(
backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION],
):
attention_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=mask,
is_causal=is_causal,
scale=scale,
)
else:
attention_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
)
return torch.transpose(attention_output, axis1, axis0)
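As a rough illustration of the `torch` path above, the sketch below mirrors the availability check and the forced flash kernel. It assumes a PyTorch version around 2.3/2.4, where `torch.nn.attention.sdpa_kernel` exists and `SDPAParams` takes the six arguments used in the diff, plus float16 inputs on a CUDA device; treat it as a sketch, not the exact backend code.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical inputs shaped (batch, num_heads, seq_len, head_dim), the layout
# expected by torch.nn.functional.scaled_dot_product_attention.
q = torch.randn(2, 4, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 4, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 4, 8, 64, dtype=torch.float16, device="cuda")

# Same check as is_flash_attention_enabled above:
# SDPAParams(query, key, value, attn_mask, dropout, is_causal), debug=False.
params = torch.backends.cuda.SDPAParams(q, k, v, None, 0.0, True)
if torch.backends.cuda.can_use_flash_attention(params, False):
    # Restrict scaled_dot_product_attention to the flash kernel only.
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True
        )
else:
    # Fall back to whatever kernel SDPA selects by default.
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=True
    )
print(out.shape)  # torch.Size([2, 4, 8, 64])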
32 changes: 29 additions & 3 deletions keras/src/ops/nn.py
@@ -2131,7 +2131,16 @@ def __init__(self, is_causal=False):
super().__init__()
self.is_causal = is_causal

def call(self, query, key, value, bias=None, mask=None, scale=None):
def call(
self,
query,
key,
value,
bias=None,
mask=None,
scale=None,
flash_attention=False,
):
return backend.nn.dot_product_attention(
query,
key,
@@ -2140,10 +2149,18 @@ def call(self, query, key, value, bias=None, mask=None, scale=None):
mask=mask,
scale=scale,
is_causal=self.is_causal,
flash_attention=flash_attention,
)

def compute_output_spec(
self, query, key, value, bias=None, mask=None, scale=None
self,
query,
key,
value,
bias=None,
mask=None,
scale=None,
flash_attention=False,
):
return KerasTensor(query.shape, dtype=query.dtype)

@@ -2152,7 +2169,14 @@ def compute_output_spec(
["keras.ops.dot_product_attention", "keras.ops.nn.dot_product_attention"]
)
def dot_product_attention(
query, key, value, bias=None, mask=None, scale=None, is_causal=False
query,
key,
value,
bias=None,
mask=None,
scale=None,
is_causal=False,
flash_attention=False,
):
"""Scaled dot product attention function.

@@ -2207,6 +2231,7 @@ def dot_product_attention(
bias=bias,
mask=mask,
scale=scale,
flash_attention=flash_attention,
)
return backend.nn.dot_product_attention(
query,
@@ -2216,4 +2241,5 @@ def dot_product_attention(
mask=mask,
scale=scale,
is_causal=is_causal,
flash_attention=flash_attention,
)
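A short usage sketch of the public op with the new flag. It assumes the JAX backend with jax>=0.4.31 and a cuDNN-capable GPU; on the NumPy and TensorFlow backends the same call raises a `ValueError`, and on PyTorch it requires float16/bfloat16 inputs on a supported GPU. The shapes are hypothetical and follow the (batch, seq_len, num_heads, head_dim) layout.

import numpy as np
from keras import ops

# Hypothetical inputs shaped (batch, seq_len, num_heads, head_dim).
query = np.random.rand(2, 8, 4, 64).astype("float16")
key = np.random.rand(2, 8, 4, 64).astype("float16")
value = np.random.rand(2, 8, 4, 64).astype("float16")

# flash_attention=True asks the active backend for its flash kernel; backends
# that cannot provide one raise a ValueError instead of silently falling back.
out = ops.dot_product_attention(
    query, key, value, is_causal=True, flash_attention=True
)
print(out.shape)  # (2, 8, 4, 64)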
65 changes: 55 additions & 10 deletions keras/src/ops/nn_test.py
@@ -2208,9 +2208,12 @@ def test_psnr(self):
bias=(None, True),
scale=(None, 1.0),
mask_and_is_causal=((None, False), (True, False), (None, True)),
flash_attention=(True, False),
)
)
def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
def test_dot_product_attention(
self, bias, scale, mask_and_is_causal, flash_attention
):
mask, is_causal = mask_and_is_causal
query_shape = (2, 3, 4, 5)
key_shape = (2, 6, 4, 5)
@@ -2232,6 +2235,57 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
mask_shape
)

if flash_attention and backend.backend() in [
"torch",
"tensorflow",
"numpy",
]:
self.skipTest(
"Flash attention is not supported on the TF and NumPy "
"backends, and PyTorch supports it only under specific "
"requirements."
)

if flash_attention and backend.backend() == "jax":
try:
outputs = knn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
flash_attention=flash_attention,
)
except ValueError as e:
if e.args[0].startswith(
"Flash attention is not supported in your "
"current JAX version"
):
self.skipTest(
"The installed JAX version does not provide the "
"`dot_product_attention` function."
)
except RuntimeError as e:
if e.args[0] == "cuDNN is not detected.":
self.skipTest("No CuDNN to run flash attention for JAX.")
elif e.args[0] == "Require at least Ampere arch to run":
self.skipTest(
"Requires at least Ampere arch to run flash attention "
"for JAX."
)
else:
outputs = knn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
flash_attention=flash_attention,
)

expected = _dot_product_attention(
query,
key,
@@ -2241,15 +2295,6 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
scale=scale,
is_causal=is_causal,
)
outputs = knn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
)
self.assertAllClose(outputs, expected)

