
Commit

[pallas] input_output_aliases now only include refs which have been written to

PiperOrigin-RevId: 689323778
superbobry authored and Google-ML-Automation committed Oct 24, 2024
1 parent bb2e230 commit 717467a
Showing 3 changed files with 71 additions and 45 deletions.
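
The substance of the change: the per-backend discharge rules used to alias every input to an output (`input_output_aliases={i: i for i in range(len(in_avals))}`), even for refs the kernel never writes. The new shared rule only turns refs carrying a `WriteEffect` into `pallas_call` outputs. A self-contained sketch of the two alias maps, with invented indices:

```python
# Example indices only; in the real rule they come from jaxpr.effects.
num_inputs = 4
written = [1, 3]  # refs with a WriteEffect

# Before: every input aliased one-to-one, written or not.
old_aliases = {i: i for i in range(num_inputs)}

# After: written input in_idx aliases output slot out_idx.
new_aliases = {in_idx: out_idx for out_idx, in_idx in enumerate(written)}

assert old_aliases == {0: 0, 1: 1, 2: 2, 3: 3}
assert new_aliases == {1: 0, 3: 1}
```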
48 changes: 48 additions & 0 deletions jax/_src/pallas/core.py
@@ -37,6 +37,7 @@
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.state import discharge as state_discharge
+from jax._src.state import types as state_types
 from jax._src.state.types import TransformedRef
 import jax.numpy as jnp

@@ -1070,6 +1071,53 @@ def _core_map_abstract_eval(*args, jaxpr, mesh):


 _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
+
+
+def default_mesh_discharge_rule(
+    in_avals,
+    out_avals,
+    *args,
+    grid,
+    compiler_params,
+    backend,
+    jaxpr,
+):
+  """Discharges a ``core_map`` over a mesh to a ``pallas_call``."""
+  del out_avals  # Unused.
+
+  def body(*args):
+    # Due to aliasing, ``args`` contains aliased inputs and outputs so we
+    # remove outputs.
+    in_refs = args[:len(in_avals)]
+    jax_core.eval_jaxpr(jaxpr, in_refs)
+
+  assert len(jaxpr.outvars) == 0
+  modified_idxs = sorted(
+      eff.input_index
+      for eff in jaxpr.effects
+      if isinstance(eff, state_types.WriteEffect)
+  )
+  any_spec = BlockSpec(memory_space=MemorySpace.ANY)
+  from jax._src.pallas import pallas_call  # Avoid circular dependency.
+  outs = pallas_call.pallas_call(
+      body,
+      out_shape=[in_avals[idx] for idx in modified_idxs],
+      in_specs=[any_spec] * len(in_avals),
+      out_specs=[any_spec] * len(modified_idxs),
+      input_output_aliases={
+          in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
+      },
+      grid=grid,
+      compiler_params=compiler_params,
+      backend=backend,
+  )(*args)
+  # ``outs`` lacks the unmodified inputs. Add them back in.
+  all_outs = [*args]
+  for out_idx, in_idx in enumerate(modified_idxs):
+    all_outs[in_idx] = outs[out_idx]
+  return all_outs, ()


 @state_discharge.register_discharge_rule(core_map_p)
 def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs):
   if type(mesh) not in _core_map_mesh_rules:
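
A standalone sketch of the output bookkeeping in `default_mesh_discharge_rule` above: only the written refs come back from `pallas_call`, so the rule splices them into the full argument list and passes unmodified inputs through. All values are invented stand-ins:

```python
args = ["buf0", "buf1", "buf2", "buf3"]  # stand-ins for the input refs
modified_idxs = [1, 3]                   # sorted WriteEffect indices

# pallas_call returns one aliased output per written ref...
outs = ["buf1_new", "buf3_new"]

# ...and the rest of the result is just the untouched inputs.
all_outs = [*args]
for out_idx, in_idx in enumerate(modified_idxs):
    all_outs[in_idx] = outs[out_idx]

assert all_outs == ["buf0", "buf1_new", "buf2", "buf3_new"]
```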
31 changes: 10 additions & 21 deletions jax/_src/pallas/mosaic/core.py
@@ -32,6 +32,9 @@
 import jax.numpy as jnp
 import numpy as np

+# TODO(b/375357542): Remove the import once the bug is fixed.
+_ = pallas_call
+
 map, unsafe_map = util.safe_map, map
 zip, unsafe_zip = util.safe_zip, zip

@@ -250,33 +253,19 @@ def _tensorcore_mesh_discharge_rule(
     mesh,
     jaxpr,
 ):
-  del out_avals
   assert isinstance(mesh, TensorCoreMesh)
   if len(mesh.shape) > 1:
     raise NotImplementedError("Mesh must be 1D")
   core_axis_name, num_cores = list(mesh.shape.items())[0]
-  def body(*args):
-    # Due to aliasing, args contains aliased inputs and outputs so we remove
-    # outputs.
-    in_refs = args[:len(in_avals)]
-    jax_core.eval_jaxpr(jaxpr, in_refs)
-  assert len(jaxpr.outvars) == 0
-  out = pallas_call.pallas_call(
-      body,
-      out_shape=in_avals,
-      in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      out_specs=[pallas_core.BlockSpec(
-          memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      input_output_aliases={i: i for i in range(len(in_avals))},
+  return pallas_core.default_mesh_discharge_rule(
+      in_avals,
+      out_avals,
+      *args,
+      jaxpr=jaxpr,
       grid=((core_axis_name, num_cores),),
-      compiler_params=dict(
-          mosaic=dict(dimension_semantics=("parallel",)),
-      ),
+      compiler_params=TPUCompilerParams(dimension_semantics=("parallel",)),
       backend="mosaic_tpu",
-  )(*args)
-  return out, ()
+  )

 pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
     _tensorcore_mesh_discharge_rule
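
Besides delegating to the shared rule, the TPU rule swaps the loose `dict(mosaic=dict(...))` form for the typed `TPUCompilerParams` dataclass. A hedged sketch of the equivalent public spelling, assuming the `jax.experimental.pallas.tpu` export available around this release:

```python
from jax.experimental.pallas import tpu as pltpu

# Before: compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",)))
# After: a typed dataclass catches misspelled fields at construction time.
params = pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
```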
37 changes: 13 additions & 24 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -29,8 +29,6 @@
 from jax._src import effects
 from jax._src import tree_util
 from jax._src.pallas import core as pallas_core
-from jax._src.pallas import pallas_call
-from jax._src.state.types import Transform
 from jax._src.state import indexing
 from jax._src.state import types as state_types
 import jax.experimental.mosaic.gpu as mgpu
@@ -155,7 +153,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:

 @tree_util.register_pytree_node_class
 @dataclasses.dataclass(frozen=True)
-class UntileRef(Transform):
+class UntileRef(state_types.Transform):
   tiling: tuple[int, ...]

   def transform_shape(self, shape):
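
The remaining hunks in this file are the same mechanical substitution: with the direct `Transform` import removed, subclasses and annotations refer to the base class through the `state_types` module alias. A minimal sketch of the pattern; `IdentityRef` is invented for illustration:

```python
import dataclasses

from jax._src.state import types as state_types


@dataclasses.dataclass(frozen=True)
class IdentityRef(state_types.Transform):
  """Illustrative transform that leaves indices untouched."""

  def untransform_index(self, idxs):
    return idxs, self
```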
@@ -173,7 +171,7 @@ def transform_dtype(self, dtype):

   def untransform_index(
       self, idxs: tuple[Index, ...]
-  ) -> tuple[tuple[Index, ...], Transform]:
+  ) -> tuple[tuple[Index, ...], state_types.Transform]:
     untiled_idxs = idxs[: -len(self.tiling)]
     tiled_idxs = idxs[-len(self.tiling) :]
     idxs_after_tiling = []
@@ -231,7 +229,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:

 @tree_util.register_pytree_node_class
 @dataclasses.dataclass(frozen=True)
-class TransposeRef(Transform):
+class TransposeRef(state_types.Transform):
   permutation: tuple[int, ...]

   def transform_shape(self, shape):
@@ -244,7 +242,7 @@ def transform_dtype(self, dtype):

   def untransform_index(
       self, idxs: tuple[Index, ...]
-  ) -> tuple[tuple[Index, ...], Transform]:
+  ) -> tuple[tuple[Index, ...], state_types.Transform]:
     removed_dims = [
         i for i, idx in enumerate(idxs) if not isinstance(idx, slice)
     ]
@@ -316,12 +314,12 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:

 @tree_util.register_pytree_node_class
 @dataclasses.dataclass(frozen=True)
-class UnswizzleRef(Transform):
+class UnswizzleRef(state_types.Transform):
   swizzle: int

   def untransform_index(
       self, idxs: tuple[Index, ...]
-  ) -> tuple[tuple[Index, ...], Transform]:
+  ) -> tuple[tuple[Index, ...], state_types.Transform]:
     if not idxs:
       return idxs, self
     if not all(isinstance(idx, slice) for idx in idxs[-2:]):
@@ -523,29 +521,20 @@ def _gpu_mesh_discharge_rule(
     mesh,
     jaxpr,
 ):
-  del out_avals
   assert isinstance(mesh, GPUMesh)
   if mesh.cluster:
     raise NotImplementedError
   if mesh.num_threads is None:
     raise NotImplementedError
-  def body(*args):
-    # Due to aliasing, args contains aliased inputs and outputs so we remove
-    # outputs.
-    in_refs = args[:len(in_avals)]
-    jax_core.eval_jaxpr(jaxpr, in_refs)
-  assert len(jaxpr.outvars) == 0
-  any_spec = pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
-  out = pallas_call.pallas_call(
-      body,
-      out_shape=in_avals,
-      in_specs=[any_spec] * len(in_avals),
-      out_specs=[any_spec] * len(in_avals),
-      input_output_aliases={i: i for i in range(len(in_avals))},
+  return pallas_core.default_mesh_discharge_rule(
+      in_avals,
+      out_avals,
+      *args,
+      jaxpr=jaxpr,
       grid=tuple(mesh.shape.items()),
       backend="mosaic_gpu",
-  )(*args)
-  return out, ()
+      compiler_params=GPUCompilerParams(),
+  )

 pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule

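
With both backends now delegating, `default_mesh_discharge_rule` is the single place that encodes the write-effect aliasing policy. A hedged sketch of how a hypothetical third mesh type would plug into the same registry; `MyMesh` and `"my_backend"` are invented, so this registers but would not actually compile:

```python
from jax._src.pallas import core as pallas_core


class MyMesh:
  """Hypothetical mesh with a single two-core axis."""
  shape = {"core": 2}


def _my_mesh_discharge_rule(in_avals, out_avals, *args, mesh, jaxpr):
  assert isinstance(mesh, MyMesh)
  return pallas_core.default_mesh_discharge_rule(
      in_avals,
      out_avals,
      *args,
      jaxpr=jaxpr,
      grid=tuple(mesh.shape.items()),
      compiler_params={},          # backend-specific params would go here
      backend="my_backend",        # not a real pallas_call backend
  )


pallas_core._core_map_mesh_rules[MyMesh] = _my_mesh_discharge_rule
```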
