From 717467a82f32e8f9b9e2b2843fddcf4ed3942473 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Thu, 24 Oct 2024 04:17:16 -0700
Subject: [PATCH] [pallas] `input_output_aliases` now only include refs which
 have been written to

PiperOrigin-RevId: 689323778
---
 jax/_src/pallas/core.py            | 48 ++++++++++++++++++++++++++++++
 jax/_src/pallas/mosaic/core.py     | 31 +++++++------------
 jax/_src/pallas/mosaic_gpu/core.py | 37 ++++++++---------------
 3 files changed, 71 insertions(+), 45 deletions(-)

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index fd7883109b9c..dad45bbae207 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -37,6 +37,7 @@
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.state import discharge as state_discharge
+from jax._src.state import types as state_types
 from jax._src.state.types import TransformedRef
 import jax.numpy as jnp
 import numpy as np
@@ -1070,6 +1071,53 @@ def _core_map_abstract_eval(*args, jaxpr, mesh):
 
 
 _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
+
+
+def default_mesh_discharge_rule(
+    in_avals,
+    out_avals,
+    *args,
+    grid,
+    compiler_params,
+    backend,
+    jaxpr,
+):
+  """Discharges a ``core_map`` over a mesh to a ``pallas_call``."""
+  del out_avals  # Unused.
+
+  def body(*args):
+    # Due to aliasing, ``args`` contains aliased inputs and outputs so we
+    # remove outputs.
+    in_refs = args[:len(in_avals)]
+    jax_core.eval_jaxpr(jaxpr, in_refs)
+
+  assert len(jaxpr.outvars) == 0
+  modified_idxs = sorted(
+      eff.input_index
+      for eff in jaxpr.effects
+      if isinstance(eff, state_types.WriteEffect)
+  )
+  any_spec = BlockSpec(memory_space=MemorySpace.ANY)
+  from jax._src.pallas import pallas_call  # Avoid circular dependency.
+  outs = pallas_call.pallas_call(
+      body,
+      out_shape=[in_avals[idx] for idx in modified_idxs],
+      in_specs=[any_spec] * len(in_avals),
+      out_specs=[any_spec] * len(modified_idxs),
+      input_output_aliases={
+          in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
+      },
+      grid=grid,
+      compiler_params=compiler_params,
+      backend=backend,
+  )(*args)
+  # ``outs`` lacks the unmodified inputs. Add them back in.
+  all_outs = [*args]
+  for out_idx, in_idx in enumerate(modified_idxs):
+    all_outs[in_idx] = outs[out_idx]
+  return all_outs, ()
+
+
 @state_discharge.register_discharge_rule(core_map_p)
 def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs):
   if type(mesh) not in _core_map_mesh_rules:
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index 60207256fd05..6e16df2e54de 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -32,6 +32,9 @@
 import jax.numpy as jnp
 import numpy as np
 
+# TODO(b/375357542): Remove the import once the bug is fixed.
+_ = pallas_call
+
 map, unsafe_map = util.safe_map, map
 zip, unsafe_zip = util.safe_zip, zip
 
@@ -250,33 +253,19 @@ def _tensorcore_mesh_discharge_rule(
     mesh,
     jaxpr,
 ):
-  del out_avals
   assert isinstance(mesh, TensorCoreMesh)
   if len(mesh.shape) > 1:
     raise NotImplementedError("Mesh must be 1D")
   core_axis_name, num_cores = list(mesh.shape.items())[0]
-  def body(*args):
-    # Due to aliasing, args contains aliased inputs and outputs so we remove
-    # outputs.
-    in_refs = args[:len(in_avals)]
-    jax_core.eval_jaxpr(jaxpr, in_refs)
-  assert len(jaxpr.outvars) == 0
-  out = pallas_call.pallas_call(
-      body,
-      out_shape=in_avals,
-      in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      out_specs=[pallas_core.BlockSpec(
-          memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      input_output_aliases={i: i for i in range(len(in_avals))},
+  return pallas_core.default_mesh_discharge_rule(
+      in_avals,
+      out_avals,
+      *args,
+      jaxpr=jaxpr,
       grid=((core_axis_name, num_cores),),
-      compiler_params=dict(
-          mosaic=dict(dimension_semantics=("parallel",)),
-      ),
+      compiler_params=TPUCompilerParams(dimension_semantics=("parallel",)),
       backend="mosaic_tpu",
-  )(*args)
-  return out, ()
+  )
 
 pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
     _tensorcore_mesh_discharge_rule
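Note: the behavioral change is concentrated in `default_mesh_discharge_rule` above, which both the TPU rule above and the GPU rule below now delegate to. Outputs are derived from the jaxpr's `WriteEffect`s, so only refs that are actually written to become `pallas_call` outputs and receive an `input_output_aliases` entry; previously every input was aliased to an output (`{i: i for i in range(len(in_avals))}`). The sketch below mirrors that bookkeeping in plain Python to show the index mapping; it is an illustration, not JAX-internal code, and `written_idxs` stands in for the indices collected from `jaxpr.effects`.

def alias_written_refs(written_idxs):
  # Sort for a deterministic input -> output ordering, as the rule does.
  modified_idxs = sorted(written_idxs)
  # Alias each written input to its output slot: input index -> output index.
  input_output_aliases = {
      in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
  }
  return modified_idxs, input_output_aliases

def splice_outputs(args, outs, modified_idxs):
  # ``outs`` lacks the unmodified inputs; add them back in.
  all_outs = [*args]
  for out_idx, in_idx in enumerate(modified_idxs):
    all_outs[in_idx] = outs[out_idx]
  return all_outs

# Four inputs of which only refs 1 and 3 are written: the pallas_call gets two
# outputs, aliased to inputs 1 and 3; unmodified inputs pass through as-is.
modified, aliases = alias_written_refs({3, 1})
assert modified == [1, 3]
assert aliases == {1: 0, 3: 1}
assert splice_outputs(["a", "b", "c", "d"], ["B", "D"], modified) == ["a", "B", "c", "D"]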
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 9d28d72f08ad..ff22f276001f 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -29,8 +29,6 @@
 from jax._src import effects
 from jax._src import tree_util
 from jax._src.pallas import core as pallas_core
-from jax._src.pallas import pallas_call
-from jax._src.state.types import Transform
 from jax._src.state import indexing
 from jax._src.state import types as state_types
 import jax.experimental.mosaic.gpu as mgpu
@@ -155,7 +153,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:
 
 @tree_util.register_pytree_node_class
 @dataclasses.dataclass(frozen=True)
-class UntileRef(Transform):
+class UntileRef(state_types.Transform):
   tiling: tuple[int, ...]
 
   def transform_shape(self, shape):
@@ -173,7 +171,7 @@ def transform_dtype(self, dtype):
 
   def untransform_index(
       self, idxs: tuple[Index, ...]
-  ) -> tuple[tuple[Index, ...], Transform]:
+  ) -> tuple[tuple[Index, ...], state_types.Transform]:
     untiled_idxs = idxs[: -len(self.tiling)]
     tiled_idxs = idxs[-len(self.tiling) :]
     idxs_after_tiling = []
@@ -231,7 +229,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:
 
 @tree_util.register_pytree_node_class
 @dataclasses.dataclass(frozen=True)
-class TransposeRef(Transform):
+class TransposeRef(state_types.Transform):
   permutation: tuple[int, ...]
 
   def transform_shape(self, shape):
@@ -244,7 +242,7 @@ def transform_dtype(self, dtype):
 
   def untransform_index(
       self, idxs: tuple[Index, ...]
-  ) -> tuple[tuple[Index, ...], Transform]:
+  ) -> tuple[tuple[Index, ...], state_types.Transform]:
     removed_dims = [
         i for i, idx in enumerate(idxs) if not isinstance(idx, slice)
     ]
@@ -316,12 +314,12 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
 
 @tree_util.register_pytree_node_class
 @dataclasses.dataclass(frozen=True)
-class UnswizzleRef(Transform):
+class UnswizzleRef(state_types.Transform):
   swizzle: int
 
   def untransform_index(
       self, idxs: tuple[Index, ...]
-  ) -> tuple[tuple[Index, ...], Transform]:
+  ) -> tuple[tuple[Index, ...], state_types.Transform]:
     if not idxs:
       return idxs, self
     if not all(isinstance(idx, slice) for idx in idxs[-2:]):
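Note: the hunks above are mechanical: with `from jax._src.state.types import Transform` removed, the ref transforms re-parent onto `state_types.Transform`. As a worked illustration of one of these transforms, the sketch below shows what `TransposeRef.transform_shape` plausibly computes from its `permutation` field; the method body is not part of this diff, so the semantics here are an assumption based on the class and field names, not the Pallas implementation.

def transpose_shape(shape: tuple[int, ...], permutation: tuple[int, ...]) -> tuple[int, ...]:
  # Assumed semantics: dimension i of the result is dimension permutation[i]
  # of the input, as with np.transpose.
  return tuple(shape[d] for d in permutation)

assert transpose_shape((2, 3, 4), (2, 0, 1)) == (4, 2, 3)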
@@ -523,29 +521,20 @@ def _gpu_mesh_discharge_rule(
     mesh,
     jaxpr,
 ):
-  del out_avals
   assert isinstance(mesh, GPUMesh)
   if mesh.cluster:
     raise NotImplementedError
   if mesh.num_threads is None:
     raise NotImplementedError
-  def body(*args):
-    # Due to aliasing, args contains aliased inputs and outputs so we remove
-    # outputs.
-    in_refs = args[:len(in_avals)]
-    jax_core.eval_jaxpr(jaxpr, in_refs)
-  assert len(jaxpr.outvars) == 0
-  any_spec = pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
-  out = pallas_call.pallas_call(
-      body,
-      out_shape=in_avals,
-      in_specs=[any_spec] * len(in_avals),
-      out_specs=[any_spec] * len(in_avals),
-      input_output_aliases={i: i for i in range(len(in_avals))},
+  return pallas_core.default_mesh_discharge_rule(
+      in_avals,
+      out_avals,
+      *args,
+      jaxpr=jaxpr,
       grid=tuple(mesh.shape.items()),
       backend="mosaic_gpu",
-  )(*args)
-  return out, ()
+      compiler_params=GPUCompilerParams(),
+  )
 
 
 pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule
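Note: after this change, both backends reduce to the same registration pattern. The sketch below shows how a hypothetical backend could hook a new mesh type into `default_mesh_discharge_rule`; `MyMesh`, its `shape` layout, the empty `compiler_params`, and the `"my_backend"` string are stand-ins for illustration, not real JAX APIs, and the call shape simply mirrors the two real rules above.

import collections
import dataclasses

from jax._src.pallas import core as pallas_core


@dataclasses.dataclass(frozen=True)
class MyMesh:  # Hypothetical mesh type, for illustration only.
  shape: collections.OrderedDict  # Maps axis name -> axis size.


def _my_mesh_discharge_rule(in_avals, out_avals, *args, mesh, jaxpr):
  assert isinstance(mesh, MyMesh)
  return pallas_core.default_mesh_discharge_rule(
      in_avals,
      out_avals,
      *args,
      jaxpr=jaxpr,
      grid=tuple(mesh.shape.items()),
      compiler_params={},  # Backend-specific compiler params would go here.
      backend="my_backend",  # Hypothetical backend name.
  )


pallas_core._core_map_mesh_rules[MyMesh] = _my_mesh_discharge_rule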