Skip to content

Commit

Permalink
[Pallas:MGPU] Add support for indexed refs to WGMMA
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687258992
  • Loading branch information
apaszke authored and Google-ML-Automation committed Oct 18, 2024
1 parent f2edc83 commit 0ee9531
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 49 deletions.
100 changes: 52 additions & 48 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import enum
from typing import Any, Literal

import jax
from jax._src import core as jax_core
from jax._src import effects
from jax._src import state
Expand Down Expand Up @@ -391,41 +392,25 @@ def wgmma(
f" rhs={b.shape=}, acc={acc.shape}"
)

if (dtype := a.dtype) != b.dtype:
if a.dtype != b.dtype:
raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")

# Infer swizzle from a.
if not a.transforms or not isinstance(
(swizzle_transform := a.transforms[0]), gpu_core.UnswizzleRef
):
raise ValueError("WGMMA lhs must be a tiled and swizzled reference.")
a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms)
b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms)

swizzle = swizzle_transform.swizzle
swizzle_elems = swizzle // dtype.itemsize
if a.transforms[1:] != (gpu_core.UntileRef((64, swizzle_elems)),):
raise ValueError(
f"WGMMA lhs must be tiled with 64x{swizzle_elems} tiles for element type"
f" {dtype}."
)

rhs_transpose_transform = gpu_core.TransposeRef((1, 0, 2, 3))
rhs_tiling = gpu_core.UntileRef((swizzle_elems, swizzle_elems))
if b.transforms == (swizzle_transform, rhs_tiling):
rhs_transpose = False
elif b.transforms == (swizzle_transform, rhs_transpose_transform, rhs_tiling):
rhs_transpose = True
else:
raise ValueError(
f"WGMMA rhs must have {swizzle=} and be tiled with"
f" {swizzle_elems}x{swizzle_elems} tiles for element type {dtype} (and"
" optionally transposed)."
)

wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose)
wgmma_ref_p.bind(
acc,
a.ref,
b.ref,
*a_transforms_leaves,
*b_transforms_leaves,
a_transforms_tree=a_transforms_tree,
b_transforms_tree=b_transforms_tree,
)


@wgmma_ref_p.def_effectful_abstract_eval
def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, **params):
def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params):
del a_aval, b_aval, params
if not isinstance(acc_aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc_aval}")
Expand All @@ -439,23 +424,9 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, **params):


@discharge.register_discharge_rule(wgmma_ref_p)
def _wgmma_ref_discharge(
in_avals,
out_avals,
acc,
a,
b,
swizzle,
rhs_transpose,
):
def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs):
del in_avals, out_avals
return (
wgmma_p.bind(
acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose
),
None,
None,
), []
return (wgmma_p.bind(*args, **kwargs), *([None] * (len(args) - 1))), []


# Functional WGMMA, returns a shaped array. Internal.
Expand All @@ -468,10 +439,43 @@ def _wgmma_lowering(
acc,
a,
b,
swizzle,
rhs_transpose,
*transforms_leaves,
a_transforms_tree,
b_transforms_tree,
):
del ctx
_, a_aval, *_ = ctx.avals_in
a_transforms_leaves, b_transforms_leaves = util.split_list(
transforms_leaves, [a_transforms_tree.num_leaves]
)
a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)

a, a_transforms = lowering._handle_indexing(a, a_transforms)
b, b_transforms = lowering._handle_indexing(b, b_transforms)

match a_transforms:
case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
swizzle_elems = swizzle // a_aval.dtype.itemsize
if tiling != (64, swizzle_elems):
raise NotImplementedError("WGMMA lhs tiling does not fit swizzle")
case _:
raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.")

match b_transforms:
case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
rhs_transpose = False
# TODO(apaszke): Actually what we really want to test here is that we're
# only doing transposes within the tiles!
case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.TransposeRef((1, 0, 2, 3)), gpu_core.UntileRef(rhs_tiling)):
rhs_transpose = True
case _:
raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.")

if rhs_swizzle != swizzle:
raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle")
if rhs_tiling != (swizzle_elems, swizzle_elems):
raise NotImplementedError("WGMMA rhs tiling does not fit swizzle")

new_acc = mgpu.wgmma(
acc,
a,
Expand Down
37 changes: 36 additions & 1 deletion tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,42 @@ def scope(acc_ref):
res, a @ (b.T if rhs_transpose else b), rtol=1e-3
)

def test_wgmma_sliced(self):
def test_wgmma_sliced_ref(self):
  # WGMMA is fed operands that are batched refs narrowed with `.at[0]`; the
  # kernel output is compared against a matmul of the first batch entries.
  lhs_shape = (2, 64, 128)
  rhs_shape = (2, 128, 192)
  out_shape = (64, 192)

  def kernel(a_ref, b_ref, o_ref):
    def accumulate(acc_ref):
      plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0])
      return acc_ref[...]

    o_ref[...] = pl.run_scoped(accumulate, plgpu.ACC(out_shape, jnp.float32))

  def batched_spec(shape):
    # Both inputs use a block equal to the full array shape, tiled 64x64 and
    # wrapped in SwizzleTransform(128).
    return plgpu.GPUBlockSpec(
        shape,
        lambda: (0, 0, 0),
        transforms=(
            plgpu.TilingTransform((64, 64)),
            plgpu.SwizzleTransform(128),
        ),
    )

  lhs_key, rhs_key = jax.random.split(jax.random.key(42), 2)
  lhs = jax.random.uniform(lhs_key, shape=lhs_shape, dtype=jnp.float16)
  rhs = jax.random.uniform(rhs_key, shape=rhs_shape, dtype=jnp.float16)

  sliced_matmul = pl.pallas_call(
      kernel,
      in_specs=[batched_spec(lhs_shape), batched_spec(rhs_shape)],
      out_specs=plgpu.GPUBlockSpec(out_shape, lambda: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
  )
  res = sliced_matmul(lhs, rhs)
  np.testing.assert_allclose(res, lhs[0] @ rhs[0], rtol=1e-3)

def test_wgmma_sliced_acc(self):
swizzle = 128
elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize
def kernel(a_ref, b_ref, o_ref):
Expand Down

0 comments on commit 0ee9531

Please sign in to comment.