From 22f47b18cf2c86e3f8c00d3e3efe1f8a8a43b04c Mon Sep 17 00:00:00 2001 From: Hazem Date: Thu, 22 Aug 2024 21:39:01 +0300 Subject: [PATCH 01/12] added flash attention support for pytorch --- keras/src/backend/torch/nn.py | 53 +++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index de931db47d4..cae509b2ad0 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -853,3 +853,56 @@ def psnr(x1, x2, max_val): mse = torch.mean((x1 - x2) ** 2) psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse) return psnr + + +def make_causal_mask(size, dtype, device): + mask = torch.tril(torch.ones((size, size), dtype=dtype, device=device)) + mask = mask.masked_fill(mask.logical_not(), -torch.inf) + mask = mask.view((1, 1, size, size)) + return mask + + +def flash_attention( + query, key, value, attn_mask=None, dropout=0.0, is_causal=False, scale=None +): + if query.ndim < 4: + raise ValueError( + "Expected `query` to have 4 dims. " f"Received: {query.ndim}." + ) + + if key.ndim < 4: + raise ValueError( + "Expected `key` to have 4 dims. " f"Received: {key.ndim}." + ) + + if value.ndim < 4: + raise ValueError( + "Expected `value` to have 4 dims. " f"Received: {value.ndim}." + ) + + flash_attn_backend = [torch.nn.attention.SDPBackend.FLASH_ATTENTION] + mask = None + if is_causal: + mask = make_causal_mask(query.shape[-2], query.dtype, query.device) + + if attn_mask is not None and mask is not None: + attn_mask = convert_to_tensor(attn_mask) + if attn_mask.ndim == 2: + attn_mask = attn_mask.view( + (attn_mask.shape[0], 1, 1, attn_mask.shape[1]) + ) + + attn_mask = attn_mask.masked_fill(attn_mask.logical_not(), -torch.inf) + mask += attn_mask + + with torch.nn.attention.sdpa_kernel(backends=flash_attn_backend): + output = tnn.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=mask, + dropout_p=dropout, + is_causal=False, + scale=scale, + ) + return output From 7e99f0613f734bfe5a92b2cde8365493f46c40e5 Mon Sep 17 00:00:00 2001 From: Hazem Date: Thu, 22 Aug 2024 21:41:45 +0300 Subject: [PATCH 02/12] added a comment explaining why the causal mask is created manually --- keras/src/backend/torch/nn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index cae509b2ad0..2ad4bfddc15 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -883,6 +883,9 @@ def flash_attention( flash_attn_backend = [torch.nn.attention.SDPBackend.FLASH_ATTENTION] mask = None if is_causal: + # We manually create the causal mask here instead of setting + # `is_causal=True` in the PyTorch function + # because it will not accept attention mask if we did that. 
mask = make_causal_mask(query.shape[-2], query.dtype, query.device) if attn_mask is not None and mask is not None: From f3a05b0276fd6c49159ba40182439eb0ddbcfd1a Mon Sep 17 00:00:00 2001 From: Hazem Date: Fri, 4 Oct 2024 14:05:37 +0300 Subject: [PATCH 03/12] added unit tests for flash attention --- keras/api/_tf_keras/keras/ops/__init__.py | 1 + keras/api/_tf_keras/keras/ops/nn/__init__.py | 1 + keras/api/ops/__init__.py | 1 + keras/api/ops/nn/__init__.py | 1 + keras/src/backend/torch/nn.py | 75 +++++++--- keras/src/ops/nn.py | 137 +++++++++++++++++++ keras/src/ops/nn_test.py | 112 +++++++++++++++ 7 files changed, 308 insertions(+), 20 deletions(-) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 20cf46889d2..382cd129cde 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -69,6 +69,7 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from keras.src.ops.nn import elu +from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py index adce3312860..f7e2fc2a2f9 100644 --- a/keras/api/_tf_keras/keras/ops/nn/__init__.py +++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py @@ -15,6 +15,7 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from keras.src.ops.nn import elu +from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 20cf46889d2..382cd129cde 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -69,6 +69,7 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from keras.src.ops.nn import elu +from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py index adce3312860..f7e2fc2a2f9 100644 --- a/keras/api/ops/nn/__init__.py +++ b/keras/api/ops/nn/__init__.py @@ -15,6 +15,7 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from keras.src.ops.nn import elu +from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index 3ef9e7de39f..680555ba601 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -854,63 +854,98 @@ def psnr(x1, x2, max_val): psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse) return psnr - + def make_causal_mask(size, dtype, device): - mask = torch.tril(torch.ones((size, size), dtype=dtype, device=device)) - mask = mask.masked_fill(mask.logical_not(), -torch.inf) + # 1 is subtracted to invert the lower triangle with the upper one. 
+ mask = 1 - torch.tril(torch.ones((size, size), dtype=dtype, device=device)) + mask = mask.masked_fill(mask.bool(), -torch.inf) mask = mask.view((1, 1, size, size)) return mask +def merge_masks(attn_mask, input_mask): + if attn_mask is None: + return input_mask + return attn_mask + input_mask + + def flash_attention( - query, key, value, attn_mask=None, dropout=0.0, is_causal=False, scale=None + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + dropout=0.0, ): - if query.ndim < 4: + if query.ndim != 4: raise ValueError( "Expected `query` to have 4 dims. " f"Received: {query.ndim}." ) - if key.ndim < 4: + if key.ndim != 4: raise ValueError( "Expected `key` to have 4 dims. " f"Received: {key.ndim}." ) - if value.ndim < 4: + if value.ndim != 4: + raise ValueError( + f"Expected `value` to have 4 dims. Received: {value.ndim}." + ) + + if bias is not None and bias.ndim != 4: raise ValueError( - "Expected `value` to have 4 dims. " f"Received: {value.ndim}." + f"Expected `bias` to have 4 dims. Received: {bias.ndim}." ) + if mask is not None and mask.ndim != 2 and mask.ndim != 4: + raise ValueError( + "Expected `mask` to have either 2 dims or 4 dims. " + f"Received: {mask.ndim}." + ) + + query = convert_to_tensor(query) + key = convert_to_tensor(key) + value = convert_to_tensor(key) + flash_attn_backend = [torch.nn.attention.SDPBackend.FLASH_ATTENTION] - mask = None + attn_mask = None + if is_causal: # We manually create the causal mask here instead of setting # `is_causal=True` in the PyTorch function # because it will not accept attention mask if we did that. - mask = make_causal_mask(query.shape[-2], query.dtype, query.device) + attn_mask = make_causal_mask(query.shape[-2], query.dtype, query.device) - if attn_mask is not None and mask is not None: - attn_mask = convert_to_tensor(attn_mask) - if attn_mask.ndim == 2: - attn_mask = attn_mask.view( - (attn_mask.shape[0], 1, 1, attn_mask.shape[1]) - ) + if mask is not None: + mask = convert_to_tensor(mask) + mask = mask.to(dtype=query.dtype, device=query.device) - attn_mask = attn_mask.masked_fill(attn_mask.logical_not(), -torch.inf) - mask += attn_mask + if mask.ndim == 2: + mask = mask.view((mask.shape[0], 1, 1, mask.shape[1])) + + mask = mask.masked_fill(mask.logical_not(), -torch.inf) + attn_mask = merge_masks(attn_mask, mask) + + if bias is not None: + bias = convert_to_tensor(bias) + bias = bias.to(dtype=query.dtype, device=query.device) + attn_mask = merge_masks(attn_mask, bias) with torch.nn.attention.sdpa_kernel(backends=flash_attn_backend): output = tnn.scaled_dot_product_attention( query=query, key=key, value=value, - attn_mask=mask, + attn_mask=attn_mask, dropout_p=dropout, is_causal=False, scale=scale, ) return output - + def _get_large_negative(dtype): dtype = backend.standardize_dtype(dtype) if dtype == "float16": diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index c0f65dc87cc..7a5223d800f 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -2217,3 +2217,140 @@ def dot_product_attention( scale=scale, is_causal=is_causal, ) + + +class FlashAttention(Operation): + def __init__(self): + super().__init__() + + def call( + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + dropout=0.0, + ): + return backend.nn.flash_attention( + query, + key, + value, + bias, + mask, + scale, + is_causal, + dropout, + ) + + def compute_output_spec( + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + dropout=0.0, + ): + 
return KerasTensor(shape=query.shape, dtype=query.dtype) + + +@keras_export(["keras.ops.flash_attention", "keras.ops.nn.flash_attention"]) +def flash_attention( + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + dropout=0.0, +): + """Flash attention function. + + Computes the attention function using Flash attention algorithm + on Q (`query`), K (`key`), and V(`value`): + `attention(Q, K, V) = softmax(Q * K / sqrt(d)) * V`. If we define `logits` + as the output of `Q * K` and the `probs` as the output of `softmax`. + + Throughout this function, we utilize the following notation to represent the + shape of array: + - B: batch size + - S: length of the key/value + - T: length of the query + - N: number of attention heads + - H: dimensions of each attention head + - K: number of key/value heads + - G: number of groups, which equals to `N // K` + + Args: + query: The query array with the shape of `(B, N, T, H)`. + key: The key array with the shape of `(B, N, T, H)`. When `K` equals + `N`, multi-headed attention (MHA) is performed. Otherwise, grouped + query attention (GQA) is performed if `N` is a multiple of `K`. and + multi-query attention (MQA) is performed if `K==1` (a special case + of GQA). + value: The value array with the same shape of `key`. + bias: Optional bias array to be added to logits. The shape must be + broadcastable to `(B, N, T, S)`. + mask: Optional mask array used to filter out logits. It is a boolean + mask where `True` indicates the element should take part in + attention. For an additive mask, users should pass it to bias. The + shape must be broadcastable to `(B, N, T, S)`. + scale: Optional scale for the logits. If `None`, the scale will be set + to `1.0 / sqrt(H)`. + is_causal: Whether to apply causal mask. + + Returns: + An array of the attention output with the same shape of `query`. + + Example: + + >>> query = keras.random.normal((2, 8, 4, 16)) + >>> key = keras.random.normal((2, 8, 6, 16)) + >>> value = keras.random.normal((2, 8, 6, 16)) + >>> keras.ops.nn.dot_product_attention(query, key, value).shape + (2, 8, 4, 16) + """ + framework = backend.backend() + if framework in ["tensorflow", "jax"]: + raise ValueError( + "Flash attention is currently supported in `torch` " + f"backend only. Received: {framework}" + ) + if any_symbolic_tensors( + ( + query, + key, + value, + bias, + mask, + scale, + is_causal, + dropout, + ) + ): + return FlashAttention().symbolic_call( + query=query, + key=key, + value=value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + dropout=dropout, + ) + return backend.nn.flash_attention( + query, + key, + value, + bias, + mask, + scale, + is_causal, + dropout, + ) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index fe8d34fc656..e21fed458fd 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -732,6 +732,24 @@ def test_dot_product_attention(self): out = knn.dot_product_attention(query, key, value) self.assertEqual(out.shape, query.shape) + @parameterized.named_parameters( + named_product(mask=(True, False), is_causal=(True, False)) + ) + @pytest.mark.skipif( + backend.backend() in ["tensorflow", "jax"], reason="Not supported yet." 
+ ) + def test_flash_attention(self, mask, is_causal): + num_heads, seqlen, embed_dim = 2, 10, 16 + x = KerasTensor([None, num_heads, seqlen, embed_dim]) + if mask: + attn_mask = KerasTensor([None, num_heads, seqlen, seqlen]) + else: + attn_mask = None + output = ops.flash_attention( + x, x, x, mask=attn_mask, is_causal=is_causal + ) + self.assertEqual(output.shape, x.shape) + class NNOpsStaticShapeTest(testing.TestCase): def test_relu(self): @@ -1205,6 +1223,24 @@ def test_dot_product_attention(self): out = knn.dot_product_attention(query, key, value) self.assertEqual(out.shape, query.shape) + @parameterized.named_parameters( + named_product(mask=(True, False), is_causal=(True, False)) + ) + @pytest.mark.skipif( + backend.backend() in ["tensorflow", "jax"], reason="Not supported yet." + ) + def test_flash_attention(self, mask, is_causal): + num_heads, seqlen, embed_dim = 2, 10, 16 + x = KerasTensor([None, num_heads, seqlen, embed_dim]) + if mask: + attn_mask = KerasTensor([None, num_heads, seqlen, seqlen]) + else: + attn_mask = None + output = ops.flash_attention( + x, x, x, mask=attn_mask, is_causal=is_causal + ) + self.assertEqual(output.shape, x.shape) + class NNOpsCorrectnessTest(testing.TestCase): def test_relu(self): @@ -2252,6 +2288,60 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal): ) self.assertAllClose(outputs, expected) + @pytest.mark.skipif( + backend.backend() in ["tensorflow", "jax"], reason="Not supported yet." + ) + def test_flash_attention(self): + query_shape = (1, 4, 2, 2) + key_shape = (1, 4, 2, 2) + mask_shape = (1, 4, 2, 2) + query = ( + np.arange(math.prod(query_shape), dtype=float) + .reshape(query_shape) + .astype("float32") + ) + key = ( + np.arange(math.prod(key_shape), dtype=float) + .reshape(key_shape) + .astype("float32") + ) + value = ( + np.arange(math.prod(key_shape), dtype=float) + .reshape(key_shape) + .astype("float32") + ) + + attn_mask = ( + np.arange(1, math.prod(mask_shape) + 1) + .reshape(mask_shape) + .astype("float32") + ) + attn_bias = ( + np.arange(math.prod(mask_shape), dtype=float) + .reshape(mask_shape) + .astype("float32") + ) + + outputs = ops.flash_attention( + query, + key, + value, + bias=attn_bias, + mask=attn_mask, + is_causal=True, + ) + expected_output = np.array( + [ + [ + [[0.0000, 1.0000], [1.9998, 2.9998]], + [[4.0000, 5.0000], [6.0000, 7.0000]], + [[8.0000, 9.0000], [10.0000, 11.0000]], + [[12.0000, 13.0000], [14.0000, 15.0000]], + ] + ] + ) + self.assertAllClose(outputs, expected_output, atol=0.0001) + class NNOpsDtypeTest(testing.TestCase): """Test the dtype to verify that the behavior matches JAX.""" @@ -2631,6 +2721,28 @@ def test_dot_product_attention(self, dtype): expected_dtype, ) + @parameterized.named_parameters( + named_product(dtype=("bfloat16", "float16", "float32")) + ) + @pytest.mark.skipif( + backend.backend() in ["tensorflow", "jax"], + reason="Not supported in tensorflow or jax yet.", + ) + def test_flash_attention(self, dtype): + query = knp.ones((2, 3, 3, 4), dtype=dtype) + key = knp.ones((2, 3, 3, 4), dtype=dtype) + value = knp.ones((2, 3, 3, 4), dtype=dtype) + expected_dtype = dtype + + eager_output = ops.flash_attention(query, key, value) + sym_output = ops.FlashAttention().symbolic_call(query, key, value) + + self.assertDType(eager_output, expected_dtype) + self.assertDType( + sym_output, + expected_dtype, + ) + class NNOpsBehaviorTest(testing.TestCase): def test_logit_recovery_binary_crossentropy(self): From f5858594c4482ba0b329b2977eccccc466faed97 Mon Sep 17 00:00:00 2001 From: 
Hazem Date: Fri, 4 Oct 2024 14:15:38 +0300 Subject: [PATCH 04/12] added test skipping for flash attention for numpy --- keras/src/ops/nn_test.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index e21fed458fd..f054d452216 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -736,7 +736,8 @@ def test_dot_product_attention(self): named_product(mask=(True, False), is_causal=(True, False)) ) @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax"], reason="Not supported yet." + backend.backend() in ["tensorflow", "jax", "numpy"], + reason="Not supported.", ) def test_flash_attention(self, mask, is_causal): num_heads, seqlen, embed_dim = 2, 10, 16 @@ -1227,7 +1228,8 @@ def test_dot_product_attention(self): named_product(mask=(True, False), is_causal=(True, False)) ) @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax"], reason="Not supported yet." + backend.backend() in ["tensorflow", "jax", "numpy"], + reason="Not supported.", ) def test_flash_attention(self, mask, is_causal): num_heads, seqlen, embed_dim = 2, 10, 16 @@ -2289,7 +2291,8 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal): self.assertAllClose(outputs, expected) @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax"], reason="Not supported yet." + backend.backend() in ["tensorflow", "jax", "numpy"], + reason="Not supported.", ) def test_flash_attention(self): query_shape = (1, 4, 2, 2) @@ -2725,8 +2728,8 @@ def test_dot_product_attention(self, dtype): named_product(dtype=("bfloat16", "float16", "float32")) ) @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax"], - reason="Not supported in tensorflow or jax yet.", + backend.backend() in ["tensorflow", "jax", "numpy"], + reason="Not supported.", ) def test_flash_attention(self, dtype): query = knp.ones((2, 3, 3, 4), dtype=dtype) From e614e75b634bda37593b29c3901714bcd2a9e5ba Mon Sep 17 00:00:00 2001 From: Hazem Date: Sun, 6 Oct 2024 18:03:33 +0300 Subject: [PATCH 05/12] removed flash attn op and added support for flash attention inside dot_product_attention op --- keras/api/_tf_keras/keras/ops/__init__.py | 1 - keras/api/_tf_keras/keras/ops/nn/__init__.py | 1 - keras/api/ops/__init__.py | 1 - keras/api/ops/nn/__init__.py | 1 - keras/src/backend/jax/nn.py | 18 +- keras/src/backend/numpy/nn.py | 11 +- keras/src/backend/tensorflow/nn.py | 14 +- keras/src/backend/torch/nn.py | 119 +++---------- keras/src/ops/nn.py | 169 ++++--------------- keras/src/ops/nn_test.py | 121 +------------ 10 files changed, 98 insertions(+), 358 deletions(-) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 382cd129cde..20cf46889d2 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -69,7 +69,6 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from keras.src.ops.nn import elu -from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py index f7e2fc2a2f9..adce3312860 100644 --- a/keras/api/_tf_keras/keras/ops/nn/__init__.py +++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py @@ -15,7 +15,6 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from 
keras.src.ops.nn import elu -from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 382cd129cde..20cf46889d2 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -69,7 +69,6 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from keras.src.ops.nn import elu -from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py index f7e2fc2a2f9..adce3312860 100644 --- a/keras/api/ops/nn/__init__.py +++ b/keras/api/ops/nn/__init__.py @@ -15,7 +15,6 @@ from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import dot_product_attention from keras.src.ops.nn import elu -from keras.src.ops.nn import flash_attention from keras.src.ops.nn import gelu from keras.src.ops.nn import hard_sigmoid from keras.src.ops.nn import hard_silu diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index b549b3517e2..cba73918976 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -986,7 +986,14 @@ def _dot_product_attention_core( def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): query = convert_to_tensor(query) key = convert_to_tensor(key) @@ -1000,6 +1007,7 @@ def dot_product_attention( # `dot_product_attention` is only available in jax>=0.4.31 if hasattr(jax.nn, "dot_product_attention"): + implementation = "cudnn" if flash_attention else "xla" return jax.nn.dot_product_attention( query, key, @@ -1008,6 +1016,14 @@ def dot_product_attention( mask=mask, scale=scale, is_causal=is_causal, + implementation=implementation, + ) + + if flash_attention: + raise ValueError( + "Flash attention is not supported in your " + "current JAX version. Please update it " + "using `pip install -U jax jaxlib`." 
) # Ref: jax.nn.dot_product_attention diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index f3e02d6d5a9..eea127e554a 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -1033,8 +1033,17 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): + if flash_attention: + raise ValueError("Flash attention is not implemented in NumPy.") # Ref: jax.nn.dot_product_attention # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828 # Not support `query_seq_lengths` and `key_value_seq_lengths` args diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index bc7c1e61486..01a1aca26d0 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -964,8 +964,20 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): + if flash_attention: + raise ValueError( + "Flash attention is not supported yet in TensorFlow backend." + ) + # Ref: jax.nn.dot_product_attention # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828 # Not support `query_seq_lengths` and `key_value_seq_lengths` args diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index 680555ba601..cd61b3844e2 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -855,97 +855,6 @@ def psnr(x1, x2, max_val): return psnr -def make_causal_mask(size, dtype, device): - # 1 is subtracted to invert the lower triangle with the upper one. - mask = 1 - torch.tril(torch.ones((size, size), dtype=dtype, device=device)) - mask = mask.masked_fill(mask.bool(), -torch.inf) - mask = mask.view((1, 1, size, size)) - return mask - - -def merge_masks(attn_mask, input_mask): - if attn_mask is None: - return input_mask - return attn_mask + input_mask - - -def flash_attention( - query, - key, - value, - bias=None, - mask=None, - scale=None, - is_causal=False, - dropout=0.0, -): - if query.ndim != 4: - raise ValueError( - "Expected `query` to have 4 dims. " f"Received: {query.ndim}." - ) - - if key.ndim != 4: - raise ValueError( - "Expected `key` to have 4 dims. " f"Received: {key.ndim}." - ) - - if value.ndim != 4: - raise ValueError( - f"Expected `value` to have 4 dims. Received: {value.ndim}." - ) - - if bias is not None and bias.ndim != 4: - raise ValueError( - f"Expected `bias` to have 4 dims. Received: {bias.ndim}." - ) - - if mask is not None and mask.ndim != 2 and mask.ndim != 4: - raise ValueError( - "Expected `mask` to have either 2 dims or 4 dims. " - f"Received: {mask.ndim}." - ) - - query = convert_to_tensor(query) - key = convert_to_tensor(key) - value = convert_to_tensor(key) - - flash_attn_backend = [torch.nn.attention.SDPBackend.FLASH_ATTENTION] - attn_mask = None - - if is_causal: - # We manually create the causal mask here instead of setting - # `is_causal=True` in the PyTorch function - # because it will not accept attention mask if we did that. 
- attn_mask = make_causal_mask(query.shape[-2], query.dtype, query.device) - - if mask is not None: - mask = convert_to_tensor(mask) - mask = mask.to(dtype=query.dtype, device=query.device) - - if mask.ndim == 2: - mask = mask.view((mask.shape[0], 1, 1, mask.shape[1])) - - mask = mask.masked_fill(mask.logical_not(), -torch.inf) - attn_mask = merge_masks(attn_mask, mask) - - if bias is not None: - bias = convert_to_tensor(bias) - bias = bias.to(dtype=query.dtype, device=query.device) - attn_mask = merge_masks(attn_mask, bias) - - with torch.nn.attention.sdpa_kernel(backends=flash_attn_backend): - output = tnn.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout, - is_causal=False, - scale=scale, - ) - return output - - def _get_large_negative(dtype): dtype = backend.standardize_dtype(dtype) if dtype == "float16": @@ -956,7 +865,14 @@ def _get_large_negative(dtype): def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): if bias is not None: raise ValueError( @@ -982,7 +898,20 @@ def dot_product_attention( query = torch.transpose(query, axis0, axis1) key = torch.transpose(key, axis0, axis1) value = torch.transpose(value, axis0, axis1) - attention_output = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale - ) + if flash_attention: + with torch.nn.attention.sdpa_kernel( + backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION], + ): + attention_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=mask, + is_causal=is_causal, + scale=scale, + ) + else: + attention_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale + ) return torch.transpose(attention_output, axis1, axis0) diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 7a5223d800f..2d779582a5b 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -2131,7 +2131,16 @@ def __init__(self, is_causal=False): super().__init__() self.is_causal = is_causal - def call(self, query, key, value, bias=None, mask=None, scale=None): + def call( + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + flash_attention=False, + ): return backend.nn.dot_product_attention( query, key, @@ -2140,10 +2149,18 @@ def call(self, query, key, value, bias=None, mask=None, scale=None): mask=mask, scale=scale, is_causal=self.is_causal, + flash_attention=flash_attention, ) def compute_output_spec( - self, query, key, value, bias=None, mask=None, scale=None + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + flash_attention=False, ): return KerasTensor(query.shape, dtype=query.dtype) @@ -2152,7 +2169,14 @@ def compute_output_spec( ["keras.ops.dot_product_attention", "keras.ops.nn.dot_product_attention"] ) def dot_product_attention( - query, key, value, bias=None, mask=None, scale=None, is_causal=False + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=False, ): """Scaled dot product attention function. 
@@ -2207,6 +2231,7 @@ def dot_product_attention( bias=bias, mask=mask, scale=scale, + flash_attention=flash_attention, ) return backend.nn.dot_product_attention( query, @@ -2216,141 +2241,5 @@ def dot_product_attention( mask=mask, scale=scale, is_causal=is_causal, - ) - - -class FlashAttention(Operation): - def __init__(self): - super().__init__() - - def call( - self, - query, - key, - value, - bias=None, - mask=None, - scale=None, - is_causal=False, - dropout=0.0, - ): - return backend.nn.flash_attention( - query, - key, - value, - bias, - mask, - scale, - is_causal, - dropout, - ) - - def compute_output_spec( - self, - query, - key, - value, - bias=None, - mask=None, - scale=None, - is_causal=False, - dropout=0.0, - ): - return KerasTensor(shape=query.shape, dtype=query.dtype) - - -@keras_export(["keras.ops.flash_attention", "keras.ops.nn.flash_attention"]) -def flash_attention( - query, - key, - value, - bias=None, - mask=None, - scale=None, - is_causal=False, - dropout=0.0, -): - """Flash attention function. - - Computes the attention function using Flash attention algorithm - on Q (`query`), K (`key`), and V(`value`): - `attention(Q, K, V) = softmax(Q * K / sqrt(d)) * V`. If we define `logits` - as the output of `Q * K` and the `probs` as the output of `softmax`. - - Throughout this function, we utilize the following notation to represent the - shape of array: - - B: batch size - - S: length of the key/value - - T: length of the query - - N: number of attention heads - - H: dimensions of each attention head - - K: number of key/value heads - - G: number of groups, which equals to `N // K` - - Args: - query: The query array with the shape of `(B, N, T, H)`. - key: The key array with the shape of `(B, N, T, H)`. When `K` equals - `N`, multi-headed attention (MHA) is performed. Otherwise, grouped - query attention (GQA) is performed if `N` is a multiple of `K`. and - multi-query attention (MQA) is performed if `K==1` (a special case - of GQA). - value: The value array with the same shape of `key`. - bias: Optional bias array to be added to logits. The shape must be - broadcastable to `(B, N, T, S)`. - mask: Optional mask array used to filter out logits. It is a boolean - mask where `True` indicates the element should take part in - attention. For an additive mask, users should pass it to bias. The - shape must be broadcastable to `(B, N, T, S)`. - scale: Optional scale for the logits. If `None`, the scale will be set - to `1.0 / sqrt(H)`. - is_causal: Whether to apply causal mask. - - Returns: - An array of the attention output with the same shape of `query`. - - Example: - - >>> query = keras.random.normal((2, 8, 4, 16)) - >>> key = keras.random.normal((2, 8, 6, 16)) - >>> value = keras.random.normal((2, 8, 6, 16)) - >>> keras.ops.nn.dot_product_attention(query, key, value).shape - (2, 8, 4, 16) - """ - framework = backend.backend() - if framework in ["tensorflow", "jax"]: - raise ValueError( - "Flash attention is currently supported in `torch` " - f"backend only. 
Received: {framework}" - ) - if any_symbolic_tensors( - ( - query, - key, - value, - bias, - mask, - scale, - is_causal, - dropout, - ) - ): - return FlashAttention().symbolic_call( - query=query, - key=key, - value=value, - bias=bias, - mask=mask, - scale=scale, - is_causal=is_causal, - dropout=dropout, - ) - return backend.nn.flash_attention( - query, - key, - value, - bias, - mask, - scale, - is_causal, - dropout, + flash_attention=flash_attention, ) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index f054d452216..5a8ba34eba3 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -732,25 +732,6 @@ def test_dot_product_attention(self): out = knn.dot_product_attention(query, key, value) self.assertEqual(out.shape, query.shape) - @parameterized.named_parameters( - named_product(mask=(True, False), is_causal=(True, False)) - ) - @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax", "numpy"], - reason="Not supported.", - ) - def test_flash_attention(self, mask, is_causal): - num_heads, seqlen, embed_dim = 2, 10, 16 - x = KerasTensor([None, num_heads, seqlen, embed_dim]) - if mask: - attn_mask = KerasTensor([None, num_heads, seqlen, seqlen]) - else: - attn_mask = None - output = ops.flash_attention( - x, x, x, mask=attn_mask, is_causal=is_causal - ) - self.assertEqual(output.shape, x.shape) - class NNOpsStaticShapeTest(testing.TestCase): def test_relu(self): @@ -1224,25 +1205,6 @@ def test_dot_product_attention(self): out = knn.dot_product_attention(query, key, value) self.assertEqual(out.shape, query.shape) - @parameterized.named_parameters( - named_product(mask=(True, False), is_causal=(True, False)) - ) - @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax", "numpy"], - reason="Not supported.", - ) - def test_flash_attention(self, mask, is_causal): - num_heads, seqlen, embed_dim = 2, 10, 16 - x = KerasTensor([None, num_heads, seqlen, embed_dim]) - if mask: - attn_mask = KerasTensor([None, num_heads, seqlen, seqlen]) - else: - attn_mask = None - output = ops.flash_attention( - x, x, x, mask=attn_mask, is_causal=is_causal - ) - self.assertEqual(output.shape, x.shape) - class NNOpsCorrectnessTest(testing.TestCase): def test_relu(self): @@ -2246,9 +2208,12 @@ def test_psnr(self): bias=(None, True), scale=(None, 1.0), mask_and_is_causal=((None, False), (True, False), (None, True)), + flash_attention=(True, False), ) ) - def test_dot_product_attention(self, bias, scale, mask_and_is_causal): + def test_dot_product_attention( + self, bias, scale, mask_and_is_causal, flash_attention + ): mask, is_causal = mask_and_is_causal query_shape = (2, 3, 4, 5) key_shape = (2, 6, 4, 5) @@ -2287,64 +2252,10 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal): mask=mask, scale=scale, is_causal=is_causal, + flash_attention=flash_attention, ) self.assertAllClose(outputs, expected) - @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax", "numpy"], - reason="Not supported.", - ) - def test_flash_attention(self): - query_shape = (1, 4, 2, 2) - key_shape = (1, 4, 2, 2) - mask_shape = (1, 4, 2, 2) - query = ( - np.arange(math.prod(query_shape), dtype=float) - .reshape(query_shape) - .astype("float32") - ) - key = ( - np.arange(math.prod(key_shape), dtype=float) - .reshape(key_shape) - .astype("float32") - ) - value = ( - np.arange(math.prod(key_shape), dtype=float) - .reshape(key_shape) - .astype("float32") - ) - - attn_mask = ( - np.arange(1, math.prod(mask_shape) + 1) - .reshape(mask_shape) - .astype("float32") - ) - 
attn_bias = ( - np.arange(math.prod(mask_shape), dtype=float) - .reshape(mask_shape) - .astype("float32") - ) - - outputs = ops.flash_attention( - query, - key, - value, - bias=attn_bias, - mask=attn_mask, - is_causal=True, - ) - expected_output = np.array( - [ - [ - [[0.0000, 1.0000], [1.9998, 2.9998]], - [[4.0000, 5.0000], [6.0000, 7.0000]], - [[8.0000, 9.0000], [10.0000, 11.0000]], - [[12.0000, 13.0000], [14.0000, 15.0000]], - ] - ] - ) - self.assertAllClose(outputs, expected_output, atol=0.0001) - class NNOpsDtypeTest(testing.TestCase): """Test the dtype to verify that the behavior matches JAX.""" @@ -2724,28 +2635,6 @@ def test_dot_product_attention(self, dtype): expected_dtype, ) - @parameterized.named_parameters( - named_product(dtype=("bfloat16", "float16", "float32")) - ) - @pytest.mark.skipif( - backend.backend() in ["tensorflow", "jax", "numpy"], - reason="Not supported.", - ) - def test_flash_attention(self, dtype): - query = knp.ones((2, 3, 3, 4), dtype=dtype) - key = knp.ones((2, 3, 3, 4), dtype=dtype) - value = knp.ones((2, 3, 3, 4), dtype=dtype) - expected_dtype = dtype - - eager_output = ops.flash_attention(query, key, value) - sym_output = ops.FlashAttention().symbolic_call(query, key, value) - - self.assertDType(eager_output, expected_dtype) - self.assertDType( - sym_output, - expected_dtype, - ) - class NNOpsBehaviorTest(testing.TestCase): def test_logit_recovery_binary_crossentropy(self): From 561a9a787e440ca01ff962add95ef491e40b8bc6 Mon Sep 17 00:00:00 2001 From: Hazem Date: Sun, 6 Oct 2024 18:24:26 +0300 Subject: [PATCH 06/12] added skip tests for every framework except torch --- keras/src/ops/nn_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 5a8ba34eba3..eed141f7a5c 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2235,6 +2235,11 @@ def test_dot_product_attention( mask_shape ) + if flash_attention and backend.backend() in ["tensorflow", "numpy", "jax"]: + self.skipTest( + "flash attention is not supported in tensorflow and numpy." + ) + expected = _dot_product_attention( query, key, From 86db62ac444ed628c7f104d1fae4180489867f43 Mon Sep 17 00:00:00 2001 From: Hazem Date: Sun, 6 Oct 2024 18:25:17 +0300 Subject: [PATCH 07/12] formatted files --- keras/src/ops/nn_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index eed141f7a5c..68050ed2b25 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2235,7 +2235,11 @@ def test_dot_product_attention( mask_shape ) - if flash_attention and backend.backend() in ["tensorflow", "numpy", "jax"]: + if flash_attention and backend.backend() in [ + "tensorflow", + "numpy", + "jax", + ]: self.skipTest( "flash attention is not supported in tensorflow and numpy." 
) From 43938caf9def5e6dbe9e5ca58fc6a6a0718a148f Mon Sep 17 00:00:00 2001 From: Hazem Date: Mon, 7 Oct 2024 14:40:23 +0300 Subject: [PATCH 08/12] added checks for flash attention in pytorch beforing computing attention and removed flash attention from tests --- keras/src/backend/torch/nn.py | 28 ++++++++++++++++++++++++++++ keras/src/ops/nn_test.py | 15 +-------------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index cd61b3844e2..e4291f6b84c 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -864,6 +864,19 @@ def _get_large_negative(dtype): return convert_to_tensor(val * -0.7, dtype=dtype) +def is_flash_attention_enabled(query, key, value, mask=None, is_causal=False): + params = torch.backends.cuda.SDPAParams( + query, + key, + value, + mask, + 0.0, + is_causal, + ) + is_enabled = torch.backends.cuda.can_use_flash_attention(params, False) + return is_enabled + + def dot_product_attention( query, key, @@ -898,7 +911,22 @@ def dot_product_attention( query = torch.transpose(query, axis0, axis1) key = torch.transpose(key, axis0, axis1) value = torch.transpose(value, axis0, axis1) + if flash_attention: + is_enabled = is_flash_attention_enabled( + query=query, + key=key, + value=value, + mask=mask, + is_causal=is_causal, + ) + if not is_enabled: + raise ValueError( + "Flash attention is not enabled in `torch` backend. " + "The dtype of the inputs should be float16/bfloat16 " + "and your GPU should support flash attention implementation." + ) + with torch.nn.attention.sdpa_kernel( backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION], ): diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 68050ed2b25..fe8d34fc656 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2208,12 +2208,9 @@ def test_psnr(self): bias=(None, True), scale=(None, 1.0), mask_and_is_causal=((None, False), (True, False), (None, True)), - flash_attention=(True, False), ) ) - def test_dot_product_attention( - self, bias, scale, mask_and_is_causal, flash_attention - ): + def test_dot_product_attention(self, bias, scale, mask_and_is_causal): mask, is_causal = mask_and_is_causal query_shape = (2, 3, 4, 5) key_shape = (2, 6, 4, 5) @@ -2235,15 +2232,6 @@ def test_dot_product_attention( mask_shape ) - if flash_attention and backend.backend() in [ - "tensorflow", - "numpy", - "jax", - ]: - self.skipTest( - "flash attention is not supported in tensorflow and numpy." 
- ) - expected = _dot_product_attention( query, key, @@ -2261,7 +2249,6 @@ def test_dot_product_attention( mask=mask, scale=scale, is_causal=is_causal, - flash_attention=flash_attention, ) self.assertAllClose(outputs, expected) From 738375f885f50c3bbc9c74b332deedae04957935 Mon Sep 17 00:00:00 2001 From: Hazem Date: Mon, 7 Oct 2024 18:56:50 +0300 Subject: [PATCH 09/12] added skipping tests for all frameworks except jax --- keras/src/ops/nn_test.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index fe8d34fc656..114f997c44b 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2208,9 +2208,12 @@ def test_psnr(self): bias=(None, True), scale=(None, 1.0), mask_and_is_causal=((None, False), (True, False), (None, True)), + flash_attention=(True, False), ) ) - def test_dot_product_attention(self, bias, scale, mask_and_is_causal): + def test_dot_product_attention( + self, bias, scale, mask_and_is_causal, flash_attention + ): mask, is_causal = mask_and_is_causal query_shape = (2, 3, 4, 5) key_shape = (2, 6, 4, 5) @@ -2232,6 +2235,10 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal): mask_shape ) + if flash_attention and backend.backend() in ["torch", "tensorflow", "numpy"]: + self.skipTest("Not supported in TF and NumPy and supported for " + "PyTorch with specific requirements.") + expected = _dot_product_attention( query, key, @@ -2249,6 +2256,7 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal): mask=mask, scale=scale, is_causal=is_causal, + flash_attention=flash_attention, ) self.assertAllClose(outputs, expected) From e12950c88119a5623c7c5b738e96969c9dc02de0 Mon Sep 17 00:00:00 2001 From: Hazem Date: Mon, 7 Oct 2024 18:59:35 +0300 Subject: [PATCH 10/12] formatted files --- keras/src/ops/nn_test.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 114f997c44b..3e33781108d 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2235,9 +2235,15 @@ def test_dot_product_attention( mask_shape ) - if flash_attention and backend.backend() in ["torch", "tensorflow", "numpy"]: - self.skipTest("Not supported in TF and NumPy and supported for " - "PyTorch with specific requirements.") + if flash_attention and backend.backend() in [ + "torch", + "tensorflow", + "numpy", + ]: + self.skipTest( + "Not supported in TF and NumPy and supported for " + "PyTorch with specific requirements." + ) expected = _dot_product_attention( query, From da50a6ba60aabe7a3469df8ef40e4e8471085b32 Mon Sep 17 00:00:00 2001 From: Hazem Date: Mon, 7 Oct 2024 20:20:20 +0300 Subject: [PATCH 11/12] added conditions to skip tests for jax --- keras/src/ops/nn_test.py | 51 ++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 3e33781108d..f0a402a3afd 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2245,6 +2245,47 @@ def test_dot_product_attention( "PyTorch with specific requirements." 
) + if flash_attention and backend.backend() == "jax": + try: + outputs = knn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + flash_attention=flash_attention, + ) + except ValueError as e: + if e.args[0].startswith( + "Flash attention is not supported in your " + "current JAX version" + ): + self.skipTest( + "JAX version is does not have " + "`dot_product_attention` function." + ) + except RuntimeError as e: + if e.args[0] == "cuDNN is not detected.": + self.skipTest("No CuDNN to run flash attention for JAX.") + elif e.args[0] == "Require at least Ampere arch to run": + self.skipTest( + "Requires at least Ampere arch to run flash attention " + "for JAX." + ) + else: + outputs = knn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + flash_attention=flash_attention, + ) + expected = _dot_product_attention( query, key, @@ -2254,16 +2295,6 @@ def test_dot_product_attention( scale=scale, is_causal=is_causal, ) - outputs = knn.dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - scale=scale, - is_causal=is_causal, - flash_attention=flash_attention, - ) self.assertAllClose(outputs, expected) From 57e6e5601013a681db6f8d631abe28769fb84611 Mon Sep 17 00:00:00 2001 From: Hazem Date: Mon, 7 Oct 2024 21:18:14 +0300 Subject: [PATCH 12/12] fixed typo --- keras/src/ops/nn_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index f0a402a3afd..4d75760d894 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -2263,7 +2263,7 @@ def test_dot_product_attention( "current JAX version" ): self.skipTest( - "JAX version is does not have " + "JAX version does not have " "`dot_product_attention` function." ) except RuntimeError as e:
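
End-of-series usage note: a minimal sketch of the API this series converges on, assuming the patched signature `dot_product_attention(query, key, value, bias=None, mask=None, scale=None, is_causal=False, flash_attention=False)` from patches 05-08, the `torch` backend, and a CUDA GPU/PyTorch build whose flash SDPA kernel accepts the inputs (float16/bfloat16); on other setups the backend raises a `ValueError` as shown in patch 08.

    # Minimal sketch, not part of the patches: exercises the patched
    # `flash_attention` flag on the torch backend. Shapes follow the
    # (batch, seq_len, num_heads, head_dim) layout used by the existing
    # dot_product_attention tests.
    import os

    os.environ["KERAS_BACKEND"] = "torch"  # select the backend before importing keras

    import keras

    batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
    query = keras.random.normal((batch, seq_len, num_heads, head_dim), dtype="float16")
    key = keras.random.normal((batch, seq_len, num_heads, head_dim), dtype="float16")
    value = keras.random.normal((batch, seq_len, num_heads, head_dim), dtype="float16")

    # With flash_attention=True the torch backend restricts
    # torch.nn.functional.scaled_dot_product_attention to
    # SDPBackend.FLASH_ATTENTION and raises a ValueError when
    # torch.backends.cuda.can_use_flash_attention reports the kernel
    # cannot run (unsupported dtype, GPU, or mask).
    output = keras.ops.dot_product_attention(
        query, key, value, is_causal=True, flash_attention=True
    )
    print(output.shape)  # (2, 128, 8, 64), same shape as `query`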