Enable flash attention (#20448)
* Enable flash attention

* code reformat

* address review comments

* add docstring

* update docstring

* add numerical correctness test

* code reformat

* use causal mask from call method

* address review comments

* update if

* fix tests

* update tests

* enable flash attention on TPU JAX

* update code

* minor fix

* address review comments

* fix tests

* run api_gen

* code reformat

* fix mask issue

* disable causal mask in dpa because it is computed in compute_attention_mask

* fix masks tests

* code reformat

* disable tests if env is not supported

* fix code reformat error

* fix torch GPU tests

* fix torch gpu tests

* make everything contiguous

* check if mask is not None before calling contiguous

* disable pytorch GPU test

* merge master

* code reformat

* set bias to None

* disable GPU test
divyashreepathihalli authored Nov 7, 2024
1 parent 30a6b87 commit 5bf4ac7
Showing 7 changed files with 295 additions and 21 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/config/__init__.py
@@ -13,6 +13,9 @@
from keras.src.backend.config import set_image_data_format
from keras.src.dtype_policies.dtype_policy import dtype_policy
from keras.src.dtype_policies.dtype_policy import set_dtype_policy
from keras.src.layers.attention.attention import disable_flash_attention
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.saving.serialization_lib import enable_unsafe_deserialization
from keras.src.utils.backend_utils import set_backend
from keras.src.utils.io_utils import disable_interactive_logging
3 changes: 3 additions & 0 deletions keras/api/config/__init__.py
@@ -13,6 +13,9 @@
from keras.src.backend.config import set_image_data_format
from keras.src.dtype_policies.dtype_policy import dtype_policy
from keras.src.dtype_policies.dtype_policy import set_dtype_policy
from keras.src.layers.attention.attention import disable_flash_attention
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.saving.serialization_lib import enable_unsafe_deserialization
from keras.src.utils.backend_utils import set_backend
from keras.src.utils.io_utils import disable_interactive_logging
17 changes: 15 additions & 2 deletions keras/src/backend/jax/nn.py
@@ -6,6 +6,9 @@
import jax.numpy as jnp
from jax import lax
from jax import nn as jnn
from jax.experimental.pallas.ops.tpu import (
flash_attention as flash_attention_tpu,
)

from keras.src import backend
from keras.src.backend.common.backend_utils import (
@@ -1019,7 +1022,18 @@ def dot_product_attention(
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
f"value.shape={value.shape}."
)

is_tpu = jax.devices()[0].platform == "tpu"
if is_tpu and flash_attention:
# Use TPU-optimized flash attention from Pallas
return flash_attention_tpu(
query,
key,
value,
ab=bias,
segment_ids=mask,
causal=is_causal,
sm_scale=scale,
)
# `dot_product_attention` is only available in jax>=0.4.31
if hasattr(jax.nn, "dot_product_attention"):
implementation = "cudnn" if flash_attention else "xla"
@@ -1040,7 +1054,6 @@
"current JAX version. Please update it "
"using `pip install -U jax jaxlib`."
)

# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
# Not support `query_seq_lengths` and `key_value_seq_lengths` args
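The JAX backend now special-cases TPUs: when flash attention is requested and the first device is a TPU, it calls the Pallas flash-attention kernel; otherwise it goes through `jax.nn.dot_product_attention` (jax>=0.4.31) and asks for the cuDNN kernel when flash attention is on, or plain XLA when it is off. Below is a minimal, hedged sketch of that dispatch pattern, not the Keras backend code itself; it assumes jax>=0.4.31 and that the Pallas TPU kernel takes inputs in (batch, heads, seq, head_dim) layout.

```python
import jax
import jax.numpy as jnp


def attention_sketch(query, key, value, flash_attention=False, is_causal=False):
    """query/key/value use the (batch, seq_len, num_heads, head_dim) layout."""
    if flash_attention and jax.devices()[0].platform == "tpu":
        # TPU path: Pallas flash-attention kernel (assumed to expect
        # (batch, num_heads, seq_len, head_dim), hence the transposes).
        from jax.experimental.pallas.ops.tpu.flash_attention import (
            flash_attention,
        )

        q, k, v = (jnp.transpose(t, (0, 2, 1, 3)) for t in (query, key, value))
        out = flash_attention(
            q, k, v, causal=is_causal, sm_scale=1.0 / q.shape[-1] ** 0.5
        )
        return jnp.transpose(out, (0, 2, 1, 3))
    # GPU/CPU path: request the cuDNN flash kernel if asked for, else XLA.
    implementation = "cudnn" if flash_attention else "xla"
    return jax.nn.dot_product_attention(
        query, key, value, is_causal=is_causal, implementation=implementation
    )


# bfloat16 keeps both the XLA and cuDNN paths happy.
q = k = v = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)
out = attention_sketch(q, k, v, flash_attention=False)  # XLA path off-TPU
```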
9 changes: 8 additions & 1 deletion keras/src/backend/torch/nn.py
@@ -954,7 +954,14 @@ def dot_product_attention(
scale=scale,
)
else:
if mask is not None:
mask = mask.contiguous()
attention_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=mask,
is_causal=is_causal,
scale=scale,
)
return torch.transpose(attention_output, axis1, axis0)
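On the PyTorch side, the fallback path now makes query, key, value, and any mask contiguous before calling `torch.nn.functional.scaled_dot_product_attention`, since the fused SDPA kernels can reject the non-contiguous views produced by the earlier transposes. A small hedged sketch of that call (assumes torch>=2.1 for the `scale=` argument):

```python
import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 4, 8, 16
# Transposing creates non-contiguous views, which is exactly what the
# backend has to clean up before handing tensors to the fused kernels.
query = torch.randn(batch, seq, heads, dim).transpose(1, 2)
key = torch.randn(batch, seq, heads, dim).transpose(1, 2)
value = torch.randn(batch, seq, heads, dim).transpose(1, 2)
mask = None  # optionally a boolean/additive mask of shape (batch, heads, seq, seq)

if mask is not None:
    mask = mask.contiguous()

out = F.scaled_dot_product_attention(
    query.contiguous(),
    key.contiguous(),
    value.contiguous(),
    attn_mask=mask,
    is_causal=False,
    scale=1.0 / dim**0.5,
)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```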
47 changes: 47 additions & 0 deletions keras/src/layers/attention/attention.py
@@ -1,6 +1,7 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
from keras.src.layers.layer import Layer


@@ -282,3 +283,49 @@ def get_config(self):
"dropout": self.dropout,
}
return {**base_config, **config}


@keras_export("keras.config.enable_flash_attention")
def enable_flash_attention():
"""Enable flash attention.
Flash attention offers performance optimization for attention layers,
making it especially useful for large language models (LLMs) that
benefit from faster and more memory-efficient attention computations.
Once enabled, supported layers like `MultiHeadAttention` will
use flash attention for faster computations.
"""
global_state.set_global_attribute("flash_attention", True)


@keras_export("keras.config.disable_flash_attention")
def disable_flash_attention():
"""Disable flash attention.
Flash attention offers performance optimization for attention layers,
making it especially useful for large language models (LLMs) that
benefit from faster and more memory-efficient attention computations.
Once disabled, supported layers like `MultiHeadAttention` will not
use flash attention for faster computations.
"""
global_state.set_global_attribute("flash_attention", False)


@keras_export("keras.config.is_flash_attention_enabled")
def is_flash_attention_enabled():
"""Checks whether flash attention is globally enabled in Keras.
Flash attention is a performance-optimized method for computing attention
in large models, such as transformers, allowing for faster and more
memory-efficient operations. This function checks the global Keras
configuration to determine if flash attention is enabled for compatible
layers (e.g., `MultiHeadAttention`).
Returns:
bool or None: Returns `True` if flash attention is enabled,
`False` if it is disabled, and `None` if the global
setting has not been defined.
"""
return global_state.get_global_attribute("flash_attention", default=None)
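These three functions are thin wrappers around a single `"flash_attention"` key in Keras' global state, which is why `is_flash_attention_enabled()` is tri-state: it returns `None` until one of the toggles has been called. A short usage sketch, assuming a Keras build that includes this commit:

```python
import keras

# None means "not configured"; layers fall back to their own defaults.
print(keras.config.is_flash_attention_enabled())  # None (if never toggled)

keras.config.enable_flash_attention()
print(keras.config.is_flash_attention_enabled())  # True

keras.config.disable_flash_attention()
print(keras.config.is_flash_attention_enabled())  # False
```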
86 changes: 75 additions & 11 deletions keras/src/layers/attention/multi_head_attention.py
@@ -11,6 +11,7 @@
from keras.src import regularizers
from keras.src.api_export import keras_export
from keras.src.layers.activations.softmax import Softmax
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.layers.core.einsum_dense import EinsumDense
from keras.src.layers.layer import Layer
from keras.src.layers.regularization.dropout import Dropout
@@ -52,6 +53,9 @@ class MultiHeadAttention(Layer):
feature dim (the query input's last dimension).
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
flash_attention: If unspecified, defaults to the global flash attention
configuration setting (which can be set via
`keras.config.enable_flash_attention()`).
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
@@ -104,6 +108,7 @@ def __init__(
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
@@ -131,6 +136,8 @@ def __init__(
self._activity_regularizer = regularizers.get(activity_regularizer)
self._kernel_constraint = constraints.get(kernel_constraint)
self._bias_constraint = constraints.get(bias_constraint)
self._flash_attention = flash_attention or is_flash_attention_enabled()

if isinstance(attention_axes, int):
attention_axes = (attention_axes,)
elif attention_axes and not isinstance(attention_axes, (list, tuple)):
@@ -392,7 +399,13 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
return self._softmax(attention_scores, mask=attention_mask)

def _compute_attention(
self, query, key, value, attention_mask=None, training=None
self,
query,
key,
value,
return_attention_scores,
attention_mask=None,
training=None,
):
"""Applies Dot-product attention with query, key, value tensors.
@@ -415,9 +428,57 @@
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.

# Check for flash attention constraints
if self._flash_attention and return_attention_scores:
raise ValueError(
"Returning attention scores is not supported when flash "
"attention is enabled. Please disable flash attention to access"
" attention scores."
)
if self._flash_attention and self._dropout > 0.0:
raise ValueError(
"Dropout is not supported when flash "
"attention is enabled. Please set dropout to 0.0 to use "
"flash attention."
)

# Determine whether to use dot-product attention
use_dot_product_attention = not (
self._dropout > 0.0
or return_attention_scores
or (len(query.shape) != 4)
)

if use_dot_product_attention:
if attention_mask is not None:
# Ensure attention_mask has the correct shape for broadcasting
# Expected shape: [batch_size, num_heads, query_seq_len,
# key_seq_len]. This is because masked_softmax is not supported
# in JAX.
while len(attention_mask.shape) < 4:
attention_mask = ops.expand_dims(
attention_mask, axis=1
) # Add dimension for num_heads
if attention_mask.shape[1] != self._num_heads:
attention_mask = ops.tile(
attention_mask, [1, self._num_heads, 1, 1]
)
# Directly compute the attention output using dot-product attention
attention_output = ops.dot_product_attention(
query=query,
key=key,
value=value,
bias=None,
mask=attention_mask,
scale=self._inverse_sqrt_key_dim,
is_causal=False,
flash_attention=self._flash_attention,
)
return attention_output, None

# Default behavior without flash attention, with explicit attention
# scores
query = ops.multiply(
query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
)
@@ -426,13 +487,13 @@
# attention scores.
attention_scores = ops.einsum(self._dot_product_equation, key, query)

# Apply the mask using the custom masked softmax
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if self.dropout:
# Apply dropout to the attention scores if needed
if self._dropout > 0.0:
final_attn_scores = self._dropout_layer(
attention_scores, training=training
)
@@ -460,7 +521,6 @@ def call(
):
if key is None:
key = value

attention_mask = self._compute_attention_mask(
query,
value,
@@ -470,9 +530,9 @@
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

# N = `num_attention_heads`
# H = `size_per_head`

# `query` = [B, T, N ,H]
query = self._query_dense.call(query)

@@ -481,9 +541,13 @@

# `value` = [B, S, N, H]
value = self._value_dense.call(value)

attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, training
query,
key,
value,
return_attention_scores,
attention_mask,
training,
)
attention_output = self._output_dense.call(attention_output)

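Per the constructor and `_compute_attention` changes above, `MultiHeadAttention` also takes a per-layer `flash_attention` argument, and the fused path is only used when attention scores are not requested, dropout is 0.0, and the projected tensors are rank 4; combining flash attention with `return_attention_scores=True` or non-zero dropout raises a `ValueError`. A hedged usage sketch, again assuming a build with this commit and a backend/device that supports the fused kernels (e.g. JAX on TPU or a cuDNN-capable GPU); elsewhere the call may raise or fall back:

```python
import numpy as np
import keras

# Per-layer opt-in; dropout must stay at 0.0 while flash attention is on.
layer = keras.layers.MultiHeadAttention(
    num_heads=4, key_dim=16, dropout=0.0, flash_attention=True
)

x = np.random.normal(size=(2, 8, 32)).astype("float32")
y = layer(query=x, value=x)  # fused attention path on supported hardware
print(y.shape)  # (2, 8, 32)

# Not allowed together with flash attention; raises ValueError:
# layer(query=x, value=x, return_attention_scores=True)
```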