[Pallas:MGPU] Properly commute indexing with other transforms
Doing so requires us to modify the other transforms when we attempt to
move indexing before them.

PiperOrigin-RevId: 687240515
apaszke authored and Google-ML-Automation committed Oct 18, 2024
1 parent 4db212d commit f2edc83
Showing 4 changed files with 103 additions and 62 deletions.
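
To make the commit message concrete: a Pallas ref reaches lowering with a stack of memory transforms (tiling, transpose, swizzle) plus an indexer, and the indexer has to be commuted past each transform in reverse order before the underlying memref can be sliced. A rough sketch of that loop, paraphrasing the _handle_indexing change below (illustrative, not a verbatim excerpt of the library):

indices = _ndindexer_indices(indexer)        # indices in the user-visible space
new_transforms_rev = []
for t in reversed(other_transforms):         # swizzle, then transpose, then tiling
  # Each transform rewrites the indices into the space "before" it and may have
  # to adjust itself; returning that adjusted copy is what this commit adds.
  indices, t = t.untransform_index(indices)
  new_transforms_rev.append(t)
ref = mgpu.memref_slice(ref, indices)        # slice the physical (transformed) ref
new_transforms = new_transforms_rev[::-1]    # adjusted stack used by later lowering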
jax/_src/pallas/mosaic_gpu/core.py: 39 changes (28 additions, 11 deletions)
@@ -128,20 +128,22 @@ def transform_shape(self, shape):
   def transform_dtype(self, dtype):
     return dtype
 
-  def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
-    if not all(isinstance(idx, slice) for idx in idxs):
-      raise NotImplementedError("Non-slice indices are not supported")
+  def untransform_index(
+      self, idxs: tuple[Index, ...]
+  ) -> tuple[tuple[Index, ...], Transform]:
     untiled_idxs = idxs[: -len(self.tiling)]
     tiled_idxs = idxs[-len(self.tiling) :]
     idxs_after_tiling = []
     for idx, tile in zip(tiled_idxs, self.tiling):
+      if not isinstance(idx, slice):
+        raise NotImplementedError("Non-slice indices are not supported")
       assert isinstance(idx, slice)
       if idx.step is not None and idx.step != 1:
         raise NotImplementedError("Strided slices unsupported")
       if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile):
         raise ValueError("Non-empty slices must be tile aligned")
       idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile))
-    return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling))
+    return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self
 
   def undo_to_gpu_transform(self) -> mgpu.MemRefTransform:
     return mgpu.TileTransform(self.tiling)
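
For intuition, a small standalone sketch (an assumed helper, not part of the library) of what the tiling case above computes: tile-aligned slice bounds are divided by the tile sizes, and full slices are appended for the trailing tile dims.

def untile_slices(idxs, tiling=(64, 32)):
  # idxs index the untiled (logical) shape; the result indexes the tiled
  # shape (..., rows // 64, cols // 32, 64, 32). Assumes tile-aligned slices.
  untiled, tiled = idxs[:-len(tiling)], idxs[-len(tiling):]
  outer = tuple(slice(s.start // t, s.stop // t) for s, t in zip(tiled, tiling))
  return (*untiled, *outer, *(slice(None) for _ in tiling))

# Rows 64:192 and cols 0:64 of a (256, 256) array tiled to (4, 8, 64, 32):
print(untile_slices((slice(64, 192), slice(0, 64))))
# -> (slice(1, 3), slice(0, 2), plus two full slices for the tile dims)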
@@ -197,8 +199,19 @@ def transform_shape(self, shape):
   def transform_dtype(self, dtype):
     return dtype
 
-  def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
-    return tuple(idxs[i] for i in _perm_inverse(self.permutation))
+  def untransform_index(
+      self, idxs: tuple[Index, ...]
+  ) -> tuple[tuple[Index, ...], Transform]:
+    removed_dims = [
+        i for i, idx in enumerate(idxs) if not isinstance(idx, slice)
+    ]
+    new_perm = tuple(
+        p - sum(d < p for d in removed_dims)
+        for p in self.permutation
+        if p not in removed_dims
+    )
+    new_idxs = tuple(idxs[i] for i in _perm_inverse(self.permutation))
+    return new_idxs, TransposeRef(new_perm)
 
   def undo_to_gpu_transform(self) -> mgpu.MemRefTransform:
     return mgpu.TransposeTransform(_perm_inverse(self.permutation))
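
A quick check of the permutation bookkeeping above, run with hypothetical values outside of Pallas: an integer index removes a dimension, so the surviving permutation entries are renumbered to skip it.

permutation = (0, 2, 1, 3, 4)
idxs = (3, slice(None), slice(None), slice(None), slice(None))  # dim 0 indexed away
removed_dims = [i for i, idx in enumerate(idxs) if not isinstance(idx, slice)]
new_perm = tuple(
    p - sum(d < p for d in removed_dims)
    for p in permutation
    if p not in removed_dims
)
print(new_perm)  # (1, 0, 2, 3): the transpose now acts on the four remaining dims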
@@ -250,11 +263,15 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
 class UnswizzleRef(Transform):
   swizzle: int
 
-  def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
+  def untransform_index(
+      self, idxs: tuple[Index, ...]
+  ) -> tuple[tuple[Index, ...], Transform]:
     if not idxs:
-      return idxs
-    if not all(isinstance(idx, slice) for idx in idxs):
-      raise NotImplementedError("Non-slice indices are not supported")
+      return idxs, self
+    if not all(isinstance(idx, slice) for idx in idxs[-2:]):
+      raise NotImplementedError(
+          "Non-slice indices are not supported in 2 minormost dims"
+      )
     last_idx = idxs[-1]
     assert isinstance(last_idx, slice)
     if last_idx.step is not None and last_idx.step != 1:
@@ -263,7 +280,7 @@ def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
         last_idx.stop is not None and last_idx.stop != self.swizzle
     ):
       raise ValueError("Swizzled dims cannot be sliced")
-    return idxs
+    return idxs, self
 
   def tree_flatten(self):
     return (), (self.swizzle,)
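
The relaxed check above means an integer index in a leading dimension (like the o_ref.at[i] in the new test) now passes through an UnswizzleRef, while the two minormost dimensions must remain whole slices. A small standalone sketch of the accepted and rejected cases (an assumed helper mirroring the checks above, step handling omitted):

def check_unswizzle_idxs(idxs, swizzle):
  if not all(isinstance(idx, slice) for idx in idxs[-2:]):
    raise NotImplementedError("non-slice index in the 2 minormost dims")
  last = idxs[-1]
  if (last.start is not None and last.start != 0) or (
      last.stop is not None and last.stop != swizzle
  ):
    raise ValueError("the swizzled (minormost) dim must stay whole")

check_unswizzle_idxs((1, slice(None), slice(None)), swizzle=128)  # accepted after this commit
# check_unswizzle_idxs((slice(None), slice(None), 3), swizzle=128) would still raise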
jax/_src/pallas/mosaic_gpu/lowering.py: 84 changes (40 additions, 44 deletions)
@@ -858,11 +858,11 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
 
 def _handle_indexing(
     ref: ir.Value, transforms: Sequence[gpu_core.Transform]
-) -> ir.Value:
+) -> tuple[ir.Value, Sequence[gpu_core.Transform]]:
   if not transforms:
     pass
   if not any(isinstance(t, indexing.NDIndexer) for t in transforms):
-    return ref
+    return ref, transforms
   if any(
       isinstance(t, indexing.NDIndexer) for t in transforms[:-1]
   ) or not isinstance(transforms[-1], indexing.NDIndexer):
@@ -872,9 +872,11 @@ def _handle_indexing(
   if indexer.int_indexer_shape:
     raise NotImplementedError("int_indexer_shape non-empty")
   indices = _ndindexer_indices(indexer)
+  new_transforms_rev = []
   for t in reversed(transforms[:-1]):
-    indices = t.untransform_index(indices)
-  return mgpu.memref_slice(ref, indices)
+    indices, new_t = t.untransform_index(indices)
+    new_transforms_rev.append(new_t)
+  return mgpu.memref_slice(ref, indices), new_transforms_rev[::-1]
 
 
 def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:
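
Since _handle_indexing now also returns the rewritten transform stack, every caller in this diff switches to tuple unpacking; a minimal sketch of the call pattern (names as in the surrounding code):

x_smem, transforms = _handle_indexing(x_smem, transforms)
# `transforms` no longer ends with the NDIndexer; what remains (possibly
# adjusted, e.g. a shrunken transpose) is matched on below to choose between
# the swizzled-and-tiled and the plain strided load/store paths.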
@@ -896,35 +898,26 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:
   return tuple(indices)
 
 
-def _is_swizzled(transforms: tuple[gpu_core.Transform, ...]) -> int | None:
-  if not transforms:
-    return None
-  if any(isinstance(t, gpu_core.UnswizzleRef) for t in transforms[1:]):
-    raise NotImplementedError(
-        "Swizzling must be the last transform applied to a ref"
-    )
-  if isinstance(t := transforms[0], gpu_core.UnswizzleRef):
-    return t.swizzle
-  return None
-
-
 @register_lowering_rule(sp.get_p)
 def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
   if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
     raise TypeError(f"Can only load from references (got {x_smem}).")
   x_aval = ctx.avals_in[0]
-  transform = jax.tree.unflatten(tree, leaves)
-  swizzle = _is_swizzled(transform)
-  x_smem = _handle_indexing(x_smem, transform)
-  # TODO(apaszke): The swizzled case should also check tiling!
-  if swizzle is None:
-    return mgpu.FragmentedArray.load_strided(
-        x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
-    )
-  else:
-    return mgpu.FragmentedArray.load_tiled(
-        x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle
-    )
+  transforms = jax.tree.unflatten(tree, leaves)
+  x_smem, transforms = _handle_indexing(x_smem, transforms)
+  match transforms:
+    case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
+      if tiling != (64, swizzle // x_aval.dtype.itemsize):
+        raise NotImplementedError("Tiling does not fit swizzle")
+      return mgpu.FragmentedArray.load_tiled(
+          x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle
+      )
+    case ():
+      return mgpu.FragmentedArray.load_strided(
+          x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
+      )
+    case _:
+      raise NotImplementedError(f"Unsupported transforms: {transforms}")
 
 
 @register_lowering_rule(sp.swap_p)
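
The match arm above only accepts the swizzled layout when the tiling agrees with the swizzle and element size; a worked check of that condition for the float32, 128-byte-swizzle combination used in the new test (a sketch, not library code):

import numpy as np

swizzle = 128                                # bytes covered by one swizzle period
itemsize = np.dtype(np.float32).itemsize     # 4 bytes
expected_tiling = (64, swizzle // itemsize)  # (64, 32), matching TilingTransform((64, 32))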
@@ -935,23 +928,26 @@ def _swap_lowering_rule(
     raise TypeError(f"Can only store arrays (got {value}).")
   if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
     raise TypeError(f"Can only store to references (got {x_smem}).")
-  transforms = jax.tree.unflatten(tree, leaves)
-  swizzle = _is_swizzled(transforms)
-  x_smem = _handle_indexing(x_smem, transforms)
   x_aval = ctx.avals_in[0]
-  # TODO(apaszke): The swizzled case should also check tiling!
-  if swizzle is None:
-    old_value = mgpu.FragmentedArray.load_strided(
-        x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
-    )
-    value.store_untiled(x_smem)
-    return old_value
-  else:
-    old_value = mgpu.FragmentedArray.load_tiled(
-        x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle
-    )
-    value.store_tiled(x_smem, swizzle=swizzle)
-    return old_value
+  transforms = jax.tree.unflatten(tree, leaves)
+  x_smem, transforms = _handle_indexing(x_smem, transforms)
+  match transforms:
+    case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
+      if tiling != (64, swizzle // x_aval.dtype.itemsize):
+        raise NotImplementedError("Tiling does not fit swizzle")
+      old_value = mgpu.FragmentedArray.load_tiled(
+          x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle
+      )
+      value.store_tiled(x_smem, swizzle=swizzle)
+      return old_value
+    case ():
+      old_value = mgpu.FragmentedArray.load_strided(
+          x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
+      )
+      value.store_untiled(x_smem)
+      return old_value
+    case _:
+      raise NotImplementedError(f"Unsupported transforms: {transforms}")
 
 
 @register_lowering_rule(pjit.pjit_p)
jax/_src/pallas/mosaic_gpu/primitives.py: 15 changes (8 additions, 7 deletions)
@@ -59,7 +59,7 @@ def _copy_smem_to_gmem_lowering(
   )
   src_transforms = src_transforms_treedef.unflatten(flat_src_transforms)
   dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
-  src = lowering._handle_indexing(src, src_transforms)
+  src, src_transforms = lowering._handle_indexing(src, src_transforms)
   copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
   mgpu.commit_shared()
   ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
@@ -80,11 +80,12 @@ def _extract_gmem_copy_params(transforms):
 def _extract_smem_copy_params(transforms):
   if not transforms:
     return {}
-  if isinstance(transforms[-1], indexing.NDIndexer):
-    transforms = transforms[:-1]
-  swizzle = lowering._is_swizzled(transforms)
-  if swizzle is not None:
-    transforms = transforms[1:]
+  # Split off swizzling, if present
+  match transforms:
+    case [gpu_core.UnswizzleRef(swizzle), *transforms]:
+      pass
+    case _:
+      swizzle = None
   gpu_transforms = tuple(t.undo_to_gpu_transform() for t in transforms[::-1])
   return dict(
       gmem_transform=gpu_transforms,
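
A minimal standalone illustration of the match statement used above, with a stand-in dataclass rather than the real gpu_core.UnswizzleRef: the class pattern captures the swizzle value and the star pattern rebinds transforms to the remaining entries.

from dataclasses import dataclass

@dataclass
class UnswizzleRef:
  swizzle: int

transforms = (UnswizzleRef(128), "tile", "transpose")
match transforms:
  case [UnswizzleRef(swizzle), *transforms]:
    pass
  case _:
    swizzle = None
print(swizzle, transforms)  # 128 ['tile', 'transpose']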
@@ -159,7 +160,7 @@ def _copy_gmem_to_smem_lowering(
   )
   src_transforms = src_transforms_treedef.unflatten(flat_src_transforms)
   dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
-  dst = lowering._handle_indexing(dst, dst_transforms)
+  dst, dst_transforms = lowering._handle_indexing(dst, dst_transforms)
   copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms)
   barrier_indexer = _extract_barrier_indexer(
       barrier_transforms_treedef.unflatten(flat_barrier_transforms)
tests/pallas/mosaic_gpu_test.py: 27 changes (27 additions, 0 deletions)
@@ -287,6 +287,33 @@ def kernel(x_ref, o_ref, barrier_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x)
 
+  def test_copy_with_transforms_and_indexing(self):
+    def kernel(x_ref, o_ref, barrier_ref):
+      for i in range(2):
+        plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier=barrier_ref)
+      plgpu.barrier_wait(barrier_ref)
+
+    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
+    out_spec = plgpu.GPUBlockSpec(
+        (2, 128, 128),
+        lambda: (0, 0, 0),
+        transforms=(
+            plgpu.TilingTransform((64, 32)),
+            plgpu.TransposeTransform((0, 2, 1, 3, 4)),
+            plgpu.SwizzleTransform(128),
+        ),
+        memory_space=plgpu.SMEM,
+    )
+    f = pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32),
+        in_specs=(in_spec,),
+        out_specs=out_spec,
+        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
+    )
+    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
+    np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0))
+
   def test_copy_gmem_to_smem_in_run_scoped(self):
     @functools.partial(
         pl.pallas_call,
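
A rough trace of the SMEM shape the new test's transform stack produces (the dimension order is inferred from the transform definitions, so treat it as an assumption):

shape = (2, 128, 128)                        # logical output block
tiled = (2, 128 // 64, 128 // 32, 64, 32)    # TilingTransform((64, 32)) -> (2, 2, 4, 64, 32)
perm = (0, 2, 1, 3, 4)
transposed = tuple(tiled[p] for p in perm)   # TransposeTransform -> (2, 4, 2, 64, 32)
# SwizzleTransform(128) keeps the shape and only reorders bytes within
# 128-byte chunks. o_ref.at[i] indexes the leading logical dimension, so that
# integer index has to be commuted through all three transforms before the
# SMEM memref can be sliced, which is what the new untransform_index methods do.
print(transposed)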
