A faster flash attention bwd implementation #177

Open · wants to merge 4 commits into base: main
177 changes: 170 additions & 7 deletions jax_triton/pallas/ops/attention.py
@@ -190,7 +190,7 @@ def _mha_forward(q, k, v, sm_scale: float, causal: bool, block_q: int,
def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
new_dout_ref, delta_ref, *,
block_q: int):
pid_m = pl.program_id(0)
pid_m = pl.program_id(2)

off_m = pl.ds(pid_m * block_q, block_q)
# load
@@ -214,15 +214,15 @@ def _preprocess_backward(out, do, l, block_q: int,
]
do_scaled, delta = pl.pallas_call(
functools.partial(_preprocess_backward_kernel, block_q=block_q),
grid=(jt.cdiv(seq_len, block_q), batch_size, num_heads),
grid=(batch_size, num_heads, jt.cdiv(seq_len, block_q)),
in_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
Collaborator: Can you rename to be (i, j, _)? Same below?
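A minimal sketch of what the suggested rename might look like, assuming the new grid ordering (batch, head, query-block) from the hunk above; seq_len and head_dim are placeholder values for illustration only:

```python
from jax_triton import pallas as pl

seq_len, head_dim = 1024, 64  # placeholder sizes, not taken from the PR

# i indexes the batch dimension, j the head dimension; the third grid axis
# (the query block) is unused by these index maps, hence `_`.
in_specs = [
    pl.BlockSpec(lambda i, j, _: (i, 0, j, 0), (None, seq_len, None, head_dim)),
    pl.BlockSpec(lambda i, j, _: (i, 0, j, 0), (None, seq_len, None, head_dim)),
    pl.BlockSpec(lambda i, j, _: (i, j, 0), (None, None, seq_len)),
]
```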

pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
],
out_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
],
num_warps=4,
num_stages=3,
@@ -232,6 +232,7 @@ def _preprocess_backward(out, do, l, block_q: int,
name="mha_preprocess_backward")(out, do, l)
return do_scaled, delta


def mha_backward_kernel(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
@@ -292,6 +293,109 @@ def inner_loop(start_q, carry):
slice(None)), dk.astype(dk_ref.dtype))
lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None)


def mha_backward_kernel_dq(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
l_ref, m_ref, delta_ref,
# Outputs
dq_ref,
*, sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]
start_q = pl.program_id(2)
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
span_q = start_q * block_q + jnp.arange(block_q)
m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
dq = jnp.zeros([block_q, block_d], dtype=jnp.float32)

def inner_loop(start_k, dq):
span_k = start_k * block_k + jnp.arange(block_k)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))

qk = pl.dot(q, k.T)
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
if sm_scale != 1.0:
qk *= sm_scale
if causal:
qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
p = jnp.exp(qk - m[:, None])
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
return dq
if causal:
upper_bound = lax.div(start_q * block_q, block_k) + 1
else:
upper_bound = jt.cdiv(seq_len, block_k)
dq = lax.fori_loop(0, upper_bound, inner_loop, dq)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
Collaborator (conversation resolved by tonywu95): I don't think we need eviction policy here
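A sketch of what dropping the hint could look like; dq_ref, start_q, block_q, and dq are the names already in scope inside mha_backward_kernel_dq above:

```python
# Same store as in the kernel above, but without the eviction hint; Triton then
# falls back to its default cache behaviour for this write.
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), slice(None)), dq)
```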



def mha_backward_kernel_dkv(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
l_ref, m_ref, delta_ref,
# Outputs
dk_ref, dv_ref,
*, sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]
start_k = pl.program_id(2)

dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
span_k = start_k * block_k + jnp.arange(block_k)

def inner_loop(start_q, carry):
dv, dk = carry
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
qk = pl.dot(q, k.T)
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
if sm_scale != 1.0:
qk *= sm_scale
if causal:
span_q = start_q * block_q + jnp.arange(block_q)
qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
p = jnp.exp(qk - m[:, None])
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
dv = dv + pl.dot(p.astype(do.dtype).T, do)
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
return dv, dk
if causal:
lower_bound = lax.div(start_k * block_k, block_q)
else:
lower_bound = 0
dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop,
(dv, dk))
pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
Collaborator (on lines +394 to +396): Nit: indentation
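For reference, one way the flagged stores could be laid out, with the continuation lines aligned under the first argument (behaviour unchanged):

```python
# Same stores as above, reformatted so the continuation indentation lines up.
pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), slice(None)),
         dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), slice(None)),
         dk.astype(dk_ref.dtype))
```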



def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
backward_pass_impl: str, num_warps: Optional[int],
num_stages: int, grid: Any, interpret: bool,
@@ -346,6 +450,65 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
Collaborator: Comment is not accurate here
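A possible rewording of the flagged code comment, assuming the intent of the split path is that dq is written once per query block rather than accumulated across launches:

```python
# Unlike the fused "triton" path above, the split dq kernel writes each query
# block exactly once, so no zero-initialized accumulator (and no
# input_output_aliases entry) is needed here.
out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32)
```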

out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32)
Collaborator: I suspect we don't need dq to be f32 anymore. Could you try q.dtype?
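A sketch of the suggested change; if the lower-precision output proves accurate enough, the trailing dq.astype(q.dtype) in the return below also becomes a no-op:

```python
# Emit dq in the input dtype instead of float32 (suggestion only; accuracy
# would need to be re-checked against the reference backward pass).
out_shapes_q = jax.ShapeDtypeStruct(q.shape, q.dtype)
```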


grid_q = (batch_size, num_heads, jt.cdiv(seq_len, block_q))
num_warps = 8
dq = pl.pallas_call(
functools.partial(mha_backward_kernel_dq, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid_q,
out_shape=out_shapes_q,
in_specs=[
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
],
out_specs=[
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward_q",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=2)(q, k, v, out, do_scaled, l, m, delta)

grid_kv = (batch_size, num_heads, jt.cdiv(seq_len, block_k))
out_shapes_kv = [
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]
dk, dv = pl.pallas_call(
functools.partial(mha_backward_kernel_dkv, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid_kv,
out_shape=out_shapes_kv,
in_specs=[
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
],
out_specs=[
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward_kv",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=2)(q, k, v, out, do_scaled, l, m, delta)
else:
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
return dq.astype(q.dtype), dk, dv
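
For context, a hedged usage sketch of the new path; it assumes the module's public mha wrapper forwards a backward_pass_impl keyword down to _mha_backward, the same way the existing "triton" path is selected:

```python
import jax
import jax.numpy as jnp

from jax_triton.pallas.ops import attention

batch, seq_len, num_heads, head_dim = 2, 1024, 4, 64
q = jnp.ones((batch, seq_len, num_heads, head_dim), jnp.float16)
k = jnp.ones_like(q)
v = jnp.ones_like(q)

# Differentiate through the attention op, selecting the split backward pass
# added in this PR (assumes `mha` exposes backward_pass_impl as a keyword).
loss = lambda q, k, v: attention.mha(
    q, k, v, causal=True, backward_pass_impl="triton_split").sum()
dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
```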