diff --git a/jax_triton/pallas/ops/attention.py b/jax_triton/pallas/ops/attention.py
index e897ffba..0b7df6c6 100644
--- a/jax_triton/pallas/ops/attention.py
+++ b/jax_triton/pallas/ops/attention.py
@@ -190,7 +190,7 @@ def _mha_forward(q, k, v, sm_scale: float, causal: bool, block_q: int,
 def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
                                 new_dout_ref, delta_ref, *,
                                 block_q: int):
-  pid_m = pl.program_id(0)
+  pid_m = pl.program_id(2)
 
   off_m = pl.ds(pid_m * block_q, block_q)
   # load
@@ -214,15 +214,15 @@ def _preprocess_backward(out, do, l, block_q: int,
   ]
   do_scaled, delta = pl.pallas_call(
       functools.partial(_preprocess_backward_kernel, block_q=block_q),
-      grid=(jt.cdiv(seq_len, block_q), batch_size, num_heads),
+      grid=(batch_size, num_heads, jt.cdiv(seq_len, block_q)),
       in_specs=[
-        pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
+        pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
       ],
       out_specs=[
-        pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
+        pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
       ],
       num_warps=4,
       num_stages=3,
@@ -232,6 +232,7 @@ def _preprocess_backward(out, do, l, block_q: int,
       name="mha_preprocess_backward")(out, do, l)
   return do_scaled, delta
 
+
 def mha_backward_kernel(
     # Inputs
     q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
@@ -292,6 +293,119 @@ def inner_loop(start_q, carry):
                       slice(None)), dk.astype(dk_ref.dtype))
   lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None)
 
+
+def mha_backward_kernel_dq(
+    # Inputs
+    q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
+    l_ref, m_ref, delta_ref,
+    # Outputs
+    dq_ref,
+    *, sm_scale: float, causal: bool,
+    block_q: int, block_d: int, block_k: int
+):
+  """Computes dq for one tile of block_q query rows.
+
+  The tile index is pl.program_id(2); key/value blocks of block_k rows are
+  visited in a loop and dq is accumulated in float32 before being stored.
+  """
+  del out_ref, l_ref  # Not needed
+  seq_len = q_ref.shape[0]
+  start_q = pl.program_id(2)
+  q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
+  span_q = start_q * block_q + jnp.arange(block_q)
+  m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
+  do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
+  di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
+  dq = jnp.zeros([block_q, block_d], dtype=jnp.float32)
+
+  def inner_loop(start_k, dq):
+    span_k = start_k * block_k + jnp.arange(block_k)
+    k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
+    v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
+
+    qk = pl.dot(q, k.T)
+    qk = qk.astype(q_ref.dtype)
+    qk = qk.astype(jnp.float32)
+    if sm_scale != 1.0:
+      qk *= sm_scale
+    if causal:
+      qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
+    p = jnp.exp(qk - m[:, None])
+    dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
+    dp = dp + pl.dot(do, v.T)
+    ds = p * dp
+    if sm_scale != 1.0:
+      ds = ds * sm_scale
+    dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
+    return dq
+  if causal:
+    # Bound by the key block holding the tile's *last* query row, so no key
+    # block is skipped when block_q > block_k (same value when they match).
+    upper_bound = lax.div((start_q + 1) * block_q - 1, block_k) + 1
+  else:
+    upper_bound = jt.cdiv(seq_len, block_k)
+  dq = lax.fori_loop(0, upper_bound, inner_loop, dq)
+  pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
+                    slice(None)), dq, eviction_policy="evict_last")
+
+
+def mha_backward_kernel_dkv(
+    # Inputs
+    q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
+    l_ref, m_ref, delta_ref,
+    # Outputs
+    dk_ref, dv_ref,
+    *, sm_scale: float, causal: bool,
+    block_q: int, block_d: int, block_k: int
+):
+  """Computes dk and dv for one tile of block_k key/value rows.
+
+  The tile index is pl.program_id(2); query blocks of block_q rows are
+  visited in a loop and dk/dv are accumulated in float32 before the store.
+  """
+  del out_ref, l_ref  # Not needed
+  seq_len = q_ref.shape[0]
+  start_k = pl.program_id(2)
+
+  dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
+  dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
+  k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
+  v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
+  span_k = start_k * block_k + jnp.arange(block_k)
+
+  def inner_loop(start_q, carry):
+    dv, dk = carry
+    q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
+    qk = pl.dot(q, k.T)
+    qk = qk.astype(q_ref.dtype)
+    qk = qk.astype(jnp.float32)
+    if sm_scale != 1.0:
+      qk *= sm_scale
+    if causal:
+      span_q = start_q * block_q + jnp.arange(block_q)
+      qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
+    m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
+    p = jnp.exp(qk - m[:, None])
+    do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
+    dv = dv + pl.dot(p.astype(do.dtype).T, do)
+    di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
+    dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
+    dp = dp + pl.dot(do, v.T)
+    ds = p * dp
+    if sm_scale != 1.0:
+      ds = ds * sm_scale
+    dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
+    return dv, dk
+  if causal:
+    lower_bound = lax.div(start_k * block_k, block_q)
+  else:
+    lower_bound = 0
+  dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop,
+                         (dv, dk))
+  pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
+                    slice(None)), dv.astype(dv_ref.dtype))
+  pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
+                    slice(None)), dk.astype(dk_ref.dtype))
+
+
 def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
                   backward_pass_impl: str, num_warps: Optional[int],
                   num_stages: int, grid: Any, interpret: bool,
@@ -346,6 +460,65 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
         num_warps=num_warps,
         num_stages=1,
         input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
+  elif backward_pass_impl == "triton_split":
+    # dq is produced in float32 here and cast back to q.dtype on return.
+    out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32)
+
+    grid_q = (batch_size, num_heads, jt.cdiv(seq_len, block_q))
+    num_warps = 8
+    dq = pl.pallas_call(
+        functools.partial(mha_backward_kernel_dq, block_q=block_q, block_d=head_dim,
+                          block_k=block_k, sm_scale=sm_scale, causal=causal),
+        grid=grid_q,
+        out_shape=out_shapes_q,
+        in_specs=[
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
+        ],
+        out_specs=[
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        name="mha_backward_q",
+        debug=debug,
+        interpret=interpret,
+        num_warps=num_warps,
+        num_stages=2)(q, k, v, out, do_scaled, l, m, delta)
+
+    grid_kv = (batch_size, num_heads, jt.cdiv(seq_len, block_k))
+    out_shapes_kv = [
+      jax.ShapeDtypeStruct(k.shape, k.dtype),
+      jax.ShapeDtypeStruct(v.shape, v.dtype),
+    ]
+    dk, dv = pl.pallas_call(
+        functools.partial(mha_backward_kernel_dkv, block_q=block_q, block_d=head_dim,
+                          block_k=block_k, sm_scale=sm_scale, causal=causal),
+        grid=grid_kv,
+        out_shape=out_shapes_kv,
+        in_specs=[
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)),
+        ],
+        out_specs=[
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        name="mha_backward_kv",
+        debug=debug,
+        interpret=interpret,
+        num_warps=num_warps,
+        num_stages=2)(q, k, v, out, do_scaled, l, m, delta)
   else:
     raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
   return dq.astype(q.dtype), dk, dv