Commit

add back previous attention kernel as an option
tonywu95 committed Jun 22, 2023
1 parent c65144b commit 7121d59
Showing 1 changed file with 100 additions and 3 deletions.
103 changes: 100 additions & 3 deletions jax_triton/pallas/ops/attention.py
@@ -232,6 +232,66 @@ def _preprocess_backward(out, do, l, block_q: int,
      name="mha_preprocess_backward")(out, do, l)
  return do_scaled, delta

def mha_backward_kernel(
    # Inputs
    q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
    l_ref, m_ref, delta_ref, _,
    # Outputs
    dq_ref, dk_ref, dv_ref,
    *, sm_scale: float, causal: bool,
    block_q: int, block_d: int, block_k: int
):
  del out_ref, l_ref  # Not needed
  seq_len = q_ref.shape[0]

  def outer_loop(start_k, _):

    dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
    dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
    k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
    v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
    span_k = start_k * block_k + jnp.arange(block_k)

    def inner_loop(start_q, carry):
      dv, dk = carry
      q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
      qk = pl.dot(q, k.T)
      qk = qk.astype(q_ref.dtype)
      qk = qk.astype(jnp.float32)
      if sm_scale != 1.0:
        qk *= sm_scale
      if causal:
        span_q = start_q * block_q + jnp.arange(block_q)
        qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
      m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
      p = jnp.exp(qk - m[:, None])
      do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
      dv = dv + pl.dot(p.astype(do.dtype).T, do)
      di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
      dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
      dp = dp + pl.dot(do, v.T)
      ds = p * dp
      if sm_scale != 1.0:
        ds = ds * sm_scale
      dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
      dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
                            slice(None)), eviction_policy="evict_last")
      dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
      pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
                        slice(None)), dq, eviction_policy="evict_last")
      return dv, dk
    if causal:
      lower_bound = lax.div(start_k * block_k, block_q)
    else:
      lower_bound = 0
    dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop,
                           (dv, dk))
    pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
                      slice(None)), dv.astype(dv_ref.dtype))
    pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
                      slice(None)), dk.astype(dk_ref.dtype))
  lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None)

def mha_backward_kernel_dq(
    # Inputs
    q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
@@ -282,7 +342,6 @@ def inner_loop(start_k, carry):
  pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
                    slice(None)), dq, eviction_policy="evict_last")


def mha_backward_kernel_dkv(
    # Inputs
    q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
@@ -336,7 +395,6 @@ def inner_loop(start_q, carry):
  pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
                    slice(None)), dk.astype(dk_ref.dtype))


def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
                  backward_pass_impl: str, num_warps: Optional[int],
                  num_stages: int, grid: Any, interpret: bool,
@@ -353,6 +411,45 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
    return jax.vjp(functools.partial(mha_reference, sm_scale=sm_scale,
                                     causal=causal), q, k, v)[1](do)
elif backward_pass_impl == "triton":
# We accumulate into dq so we need to initialize it to zeros.
dq = jnp.zeros(q.shape, jnp.float32)
out_shapes = [
jax.ShapeDtypeStruct(dq.shape, dq.dtype),
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]

grid = (batch_size, num_heads)
# TODO(sharadmv): figure out why num_warps=8 doesn't work!
num_warps = 4
dq, dk, dv = pl.pallas_call(
functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid,
out_shape=out_shapes,
in_specs=[
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
out_specs=[
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
dq = jnp.zeros(q.shape, jnp.float32)
out_shapes_q = jax.ShapeDtypeStruct(dq.shape, dq.dtype)
@@ -433,4 +530,4 @@ def mha_reference(q, k, v, sm_scale=1.0, causal: bool = True):
  mask = jnp.broadcast_to(mask, logits.shape)
  logits = jnp.where(mask, logits, float('-inf'))
  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
  return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
  return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
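
The gradients that mha_backward_kernel accumulates blockwise are the standard softmax-attention backward identities: dv = pᵀ·do, dp = do·vᵀ, ds = p ∘ (dp − delta) with delta the row-wise sum of do ∘ o (the quantity the preprocess kernel produces), then dk = dsᵀ·q and dq accumulates ds·k. Below is a minimal, unfused sketch of those identities checked against jax.vjp for a single head. It is not part of the diff: it ignores the block tiling, sm_scale, causal masking, and the normalizer bookkeeping (the kernel works with the unnormalized p = exp(qk − m) and a rescaled do_scaled), and the helper name attn is illustrative only.

import jax
import jax.numpy as jnp

def attn(q, k, v):
  # Plain single-head softmax attention: o = softmax(q @ k.T) @ v.
  return jax.nn.softmax(q @ k.T, axis=-1) @ v

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (8, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (8, 16))
do = jax.random.normal(jax.random.PRNGKey(3), (8, 16))

o, attn_vjp = jax.vjp(attn, q, k, v)
dq_ref, dk_ref, dv_ref = attn_vjp(do)

p = jax.nn.softmax(q @ k.T, axis=-1)
delta = jnp.sum(do * o, axis=-1)       # plays the role of delta_ref
dv = p.T @ do                          # dv accumulation
dp = do @ v.T                          # dp accumulation
ds = p * (dp - delta[:, None])         # softmax backward
dq = ds @ k                            # dq accumulation
dk = ds.T @ q                          # dk accumulation

assert jnp.allclose(dq, dq_ref, atol=1e-4)
assert jnp.allclose(dk, dk_ref, atol=1e-4)
assert jnp.allclose(dv, dv_ref, atol=1e-4)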

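Since the point of the commit is to expose the old kernel as a selectable option, here is a sketch of switching between the backward implementations. It assumes the module-level mha entry point accepts causal and backward_pass_impl keywords and forwards the latter to _mha_backward as the dispatch above suggests, and that the reference branch is named "xla"; the exact signature and defaults may differ, and the Triton paths need a GPU.

import functools
import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 1024, 4, 64)  # (batch, seq_len, num_heads, head_dim)
q = jax.random.normal(k1, shape, dtype=jnp.float16)
k = jax.random.normal(k2, shape, dtype=jnp.float16)
v = jax.random.normal(k3, shape, dtype=jnp.float16)

def make_loss(impl):
  f = functools.partial(attention.mha, causal=True, backward_pass_impl=impl)
  return lambda q, k, v: f(q, k, v).astype(jnp.float32).sum()

# "triton": the single-kernel backward pass re-added by this commit.
# "triton_split": the separate dq and dk/dv kernels.
# "xla": the pure-JAX reference via jax.vjp(mha_reference, ...).
grads_fused = jax.grad(make_loss("triton"), argnums=(0, 1, 2))(q, k, v)
grads_split = jax.grad(make_loss("triton_split"), argnums=(0, 1, 2))(q, k, v)
grads_ref = jax.grad(make_loss("xla"), argnums=(0, 1, 2))(q, k, v)

for fused, ref in zip(grads_fused, grads_ref):
  print(jnp.max(jnp.abs(fused.astype(jnp.float32) - ref.astype(jnp.float32))))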