A faster flash attention bwd implementation #177

Open · wants to merge 4 commits into main · Changes from 2 commits
174 changes: 171 additions & 3 deletions jax_triton/pallas/ops/attention.py
@@ -99,7 +99,7 @@ def body(start_k, carry):
"interpret", "debug"])
def mha(q, k, v,
sm_scale: float = 1.0,
- causal: bool = False,
+ causal: bool = True,
block_q: int = 128,
block_k: int = 128,
backward_pass_impl: str = "triton",
@@ -292,6 +292,109 @@ def inner_loop(start_q, carry):
slice(None)), dk.astype(dk_ref.dtype))
lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None)

def mha_backward_kernel_dq(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
l_ref, m_ref, delta_ref, _,
# Outputs
dq_ref,
*, sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]

start_q = pl.program_id(0)
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
span_q = start_q * block_q + jnp.arange(block_q)
m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), eviction_policy="evict_last")
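# Each program instance owns one query block and loops over key blocks,
# accumulating dq += ds @ k for every key block this query block attends to.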

def inner_loop(start_k, carry):
dq = carry
span_k = start_k * block_k + jnp.arange(block_k)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))

qk = pl.dot(q, k.T)
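# The cast down to the input dtype and back to float32 below presumably
# mirrors the forward kernel's precision when the logits are recomputed here.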
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
if sm_scale != 1.0:
qk *= sm_scale
if causal:
qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
p = jnp.exp(qk - m[:, None])
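# Softmax backward for this tile: dp = do @ v^T - delta, ds = p * dp.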
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
return dq
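# With causal masking, only key blocks at or before this query block can
# contribute, so the loop over key blocks can stop early.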
if causal:
upper_bound = lax.div(start_q * block_q, block_k) + 1
else:
upper_bound = jt.cdiv(seq_len, block_k)
dq = lax.fori_loop(0, upper_bound, inner_loop, dq)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
Collaborator: I don't think we need eviction policy here


def mha_backward_kernel_dkv(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
l_ref, m_ref, delta_ref,
# Outputs
dk_ref, dv_ref,
*, sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]
start_k = pl.program_id(0)

dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
span_k = start_k * block_k + jnp.arange(block_k)
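# Each program instance owns one key/value block and loops over query blocks,
# accumulating dv += p^T @ do and dk += ds^T @ q.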

def inner_loop(start_q, carry):
dv, dk = carry
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
qk = pl.dot(q, k.T)
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
if sm_scale != 1.0:
qk *= sm_scale
if causal:
span_q = start_q * block_q + jnp.arange(block_q)
qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
p = jnp.exp(qk - m[:, None])
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
dv = dv + pl.dot(p.astype(do.dtype).T, do)
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
return dv, dk
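# With causal masking, query blocks that end before this key/value block
# starts contribute nothing, so the q-loop can skip them.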
if causal:
lower_bound = lax.div(start_k * block_k, block_q)
else:
lower_bound = 0
dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop,
(dv, dk))
pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
Collaborator (on lines +394 to +396): Nit: indentation


def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
backward_pass_impl: str, num_warps: Optional[int],
num_stages: int, grid: Any, interpret: bool,
@@ -346,14 +449,79 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
Collaborator: Comment is not accurate here

dq = jnp.zeros(q.shape, jnp.float32)
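# Two separate pallas_call launches follow: the first computes dq, parallel
# over query blocks; the second computes dk and dv, parallel over key/value blocks.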
out_shapes_q = jax.ShapeDtypeStruct(dq.shape, dq.dtype)

grid_q = (jt.cdiv(seq_len, block_q), batch_size, num_heads)
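# The BlockSpecs below ignore the first grid axis; the kernel selects its
# query block via pl.program_id(0).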
# grid_q = (batch_size, num_heads)
# TODO(sharadmv): figure out why num_warps=8 doesn't work!
num_warps = 4
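# dq (input index 8) is aliased to the kernel output via
# input_output_aliases={8: 0}, so the zero-initialized buffer is updated in place.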
dq = pl.pallas_call(
functools.partial(mha_backward_kernel_dq, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid_q,
out_shape=out_shapes_q,
in_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
out_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)

grid_kv = (jt.cdiv(seq_len, block_k), batch_size, num_heads)
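# Same batch/head blocking as above, but parallel over key/value blocks;
# pl.program_id(0) selects the k/v block inside mha_backward_kernel_dkv.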
out_shapes_kv = [
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]
dk, dv = pl.pallas_call(
functools.partial(mha_backward_kernel_dkv, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid_kv,
out_shape=out_shapes_kv,
in_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
],
out_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=1)(q, k, v, out, do_scaled, l, m, delta)

else:
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
return dq.astype(q.dtype), dk, dv
mha.defvjp(_mha_forward, _mha_backward)


@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
-def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False):
+def mha_reference(q, k, v, sm_scale=1.0, causal: bool = True):
q_seq_len = q.shape[1]
kv_seq_len = k.shape[1]
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
@@ -362,4 +530,4 @@ def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False):
mask = jnp.broadcast_to(mask, logits.shape)
logits = jnp.where(mask, logits, float('-inf'))
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
-return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
+return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
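
For reference, here is a minimal usage sketch of the new path. It assumes this branch is installed, that mha is imported from jax_triton.pallas.ops.attention as in the file above, and that causal and backward_pass_impl are accepted as keyword arguments per the signature in the diff; the shapes, dtypes, and sizes are illustrative only.

import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention

batch, seq_len, num_heads, head_dim = 2, 1024, 4, 64
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
k = jax.random.normal(key_k, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
v = jax.random.normal(key_v, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)

def loss(q, k, v):
  # Route the VJP through the split dq / dkv kernels added in this PR.
  return attention.mha(q, k, v, causal=True,
                       backward_pass_impl="triton_split").sum()

dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)

Gradients produced this way should match the existing backward_pass_impl="triton" path (and mha_reference differentiated with jax.grad) up to numerical tolerance.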
Show resolved Hide resolved