Skip to content

Commit

Permalink
[Mosaic GPU] Add an initial skeleton for a layout inference pass.
Browse files Browse the repository at this point in the history
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.

At this point, layout inference is done in two passes: one "backwards" pass
(root-to-parameters), and one "forward" pass (parameters-to-root).

Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.

We do not check whether the inferred layouts can actually be lowered, which
means that the produced IR may still fail to lower at a later stage.

Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.

When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.

PiperOrigin-RevId: 702370349
  • Loading branch information
bchetioui authored and Google-ML-Automation committed Dec 11, 2024
1 parent b79dae8 commit 47bff4b
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 9 deletions.
12 changes: 10 additions & 2 deletions jax/experimental/mosaic/gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,18 @@
if dialect is not None:
  # The redundant `name as name` aliases mark every symbol as an explicit
  # re-export of this package, consistently for all imported names.
  from .dialect_lowering import (
      gpu_address_space_to_nvptx as gpu_address_space_to_nvptx,
      infer_layout as infer_layout,
      lower_mgpu_dialect as lower_mgpu_dialect,
      splat_fragmented_layout as splat_fragmented_layout,
      strided_fragmented_layout as strided_fragmented_layout,
  )
else:
  # Without the dialect, the names are still defined (as `None`) so that
  # importing them from this package does not fail.
  gpu_address_space_to_nvptx = None
  infer_layout = None
  lower_mgpu_dialect = None
  splat_fragmented_layout = None
  strided_fragmented_layout = None


from .fragmented_array import (
FragmentedArray as FragmentedArray,
Expand Down
176 changes: 171 additions & 5 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,188 @@
"""Lowering rules and pass for the MLIR Mosaic GPU dialect."""

from collections.abc import Callable
import enum
import functools
import itertools
import operator
from typing import Sequence, Type
from typing import List, Sequence, Tuple, Type, cast

from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import nvvm

from jaxlib.mlir import ir
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import nvvm
from .utils import c, ptr_as_memref, single_thread_predicate

# mypy: ignore-errors


def strided_fragmented_layout():
  """Builds the attribute for the `WGStridedFragLayout` fragmented layout."""
  kind = mgpu.FragmentedLayout.WGStridedFragLayout
  attr_text = f"#mosaic_gpu.fragmented_layout<{kind}>"
  return ir.Attribute.parse(attr_text)


def splat_fragmented_layout():
  """Builds the attribute for the `WGSplatFragLayout` fragmented layout."""
  kind = mgpu.FragmentedLayout.WGSplatFragLayout
  attr_text = f"#mosaic_gpu.fragmented_layout<{kind}>"
  return ir.Attribute.parse(attr_text)


# Registry of layout inference rules, keyed by operation name. A rule takes an
# op and returns an `(in_layouts, out_layouts)` pair of attribute lists, or
# `None` when it could not infer anything for that op.
_layout_inference_rules: dict[
str,
Callable[[ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None],
] = {}


def _add_layout_inference_rule(
    op: Type[ir.OpView],
    rule: Callable[
        [ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None
    ],
):
  """Registers `rule` as the layout inference rule for the op class `op`.

  The rule is stored in `_layout_inference_rules` under the op's
  `OPERATION_NAME`; registering another rule for the same name replaces the
  previous one.
  """
  op_name = op.OPERATION_NAME  # pytype: disable=attribute-error
  _layout_inference_rules[op_name] = rule


def _set_layout_attributes(
    op: ir.OpView,
    in_layouts: List[ir.Attribute],
    out_layouts: List[ir.Attribute],
):
  """Stores the given layouts on `op` as array attributes.

  Both the `in_layouts` and the `out_layouts` attributes are (over)written.
  """
  for attr_name, layouts in (
      ("in_layouts", in_layouts),
      ("out_layouts", out_layouts),
  ):
    op.attributes[attr_name] = ir.ArrayAttr.get(layouts)


def _extract_any_layout_from_op(op: ir.OpView) -> ir.Attribute | None:
  """Returns one layout already annotated on `op`, or `None` if there is none.

  The first entry of `in_layouts` is used when that attribute is present and
  the op has at least one operand. Otherwise, the first entry of `out_layouts`
  is used when that attribute is present and the op has at least one result.
  """
  annotations = op.attributes

  if "in_layouts" in annotations and len(op.operands) >= 1:
    in_layouts = cast(ir.ArrayAttr, annotations["in_layouts"])
    return in_layouts[0]

  if "out_layouts" in annotations and len(op.results) >= 1:
    out_layouts = cast(ir.ArrayAttr, annotations["out_layouts"])
    return out_layouts[0]

  return None


def _infer_pointwise_op_layouts(
    op: ir.OpView,
) -> Tuple[List[ir.Attribute], List[ir.Attribute]] | None:
  """Infers a single layout shared by all the operands and results of `op`.

  The layout is looked up, by decreasing priority, on:
    1. `op` itself;
    2. the users of the results of `op`;
    3. the producers of the operands of `op`.

  Operands that are block arguments have no producing op and are skipped.

  Returns:
    A pair `(in_layouts, out_layouts)` holding one copy of the layout per
    operand and per result respectively, or `None` if no layout was found.
  """

  def candidate_layouts():
    # Candidates are yielded lazily, so that users and producers are only
    # inspected when the higher-priority sources have no layout.
    yield _extract_any_layout_from_op(op)

    for op_result in op.results:
      for use in cast(ir.OpResult, op_result).uses:
        yield _extract_any_layout_from_op(use.owner)

    for operand in op.operands:
      producer = operand.owner
      # The owner of a block argument is a block, not an operation: it has no
      # attributes and therefore can not carry any layout annotation.
      if isinstance(producer, ir.Block):
        continue
      yield _extract_any_layout_from_op(producer)

  layout = next(
      (candidate for candidate in candidate_layouts() if candidate is not None),
      None,
  )
  if layout is None:
    return None

  return ([layout] * len(op.operands), [layout] * len(op.results))


# Ops whose layouts are inferred with `_infer_pointwise_op_layouts`.
for op in [arith.AddFOp, arith.ConstantOp, arith.MulFOp]:
  _add_layout_inference_rule(op, _infer_pointwise_op_layouts)


def _layout_inference_should_process_op(op: ir.OpView) -> bool:
  """Returns `True` if the layout inference pass must process `op`.

  An op is processed if and only if at least one of its operands or results
  has a `RankedTensorType` or a `VectorType`. All other ops can be skipped.
  """

  def is_array(v: ir.Value) -> bool:
    ty = v.type
    return ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty)

  return any(is_array(v) for v in itertools.chain(op.operands, op.results))


def _has_any_layout_set(op: ir.OpView) -> bool:
  """Returns whether `op` has an `in_layouts` or an `out_layouts` attribute."""
  return any(name in op.attributes for name in ("in_layouts", "out_layouts"))


class TraversalOrder(enum.Enum):
  """Directions, relative to the data flow, in which IR can be traversed."""

  # `enum.auto()` numbers the members from 1, in declaration order.
  FORWARD = enum.auto()
  BACKWARDS = enum.auto()


def traverse_op(
    op: ir.OpView,
    callback: Callable[[ir.OpView], None],
    traversal_order: TraversalOrder = TraversalOrder.FORWARD,
):
  """Recursively traverses `op` and applies `callback` to every visited op.

  The ops nested in the regions of `op` are visited first, at any nesting
  depth, and `op` itself is visited last. Each op is visited exactly once.

  Args:
    op: the op to traverse.
    callback: the function to apply to `op` and to all the ops nested in it.
    traversal_order: the order in which the ops of each block are visited:
      in block order for `TraversalOrder.FORWARD`, and in reverse block order
      otherwise.
  """
  for region in op.operation.regions:
    for block in region:
      # Work on a snapshot of the block, so that the iteration does not depend
      # on the block staying untouched while the callback runs.
      block_ops = list(block)
      if traversal_order != TraversalOrder.FORWARD:
        block_ops.reverse()
      for block_op in block_ops:
        # Recurse, so that ops nested more than one level deep are visited too.
        traverse_op(block_op, callback, traversal_order)
  callback(op)


def infer_layout(module: ir.Module):
  """Annotates the ops of `module` with `in_layouts`/`out_layouts` attributes.

  Layouts are first propagated with the registered inference rules, and the
  ops that are still unannotated afterwards get a strided fragmented layout.
  Only the ops reached by `traverse_op` that have at least one operand or
  result of vector or ranked tensor type are annotated.

  Raises:
    NotImplementedError: if an op that must be annotated has no registered
      layout inference rule.
  """

  def inference_step(op: ir.Operation):
    if not _layout_inference_should_process_op(op):
      return

    rule = _layout_inference_rules.get(op.OPERATION_NAME)  # pytype: disable=attribute-error
    if not rule:
      raise NotImplementedError(f"Can not infer layout for {op}")

    inferred_layouts = rule(op)
    if inferred_layouts is not None:
      _set_layout_attributes(op, *inferred_layouts)

  # Two passes are run over the module, so that layouts defined in the middle
  # of the computation reach every op they need to reach. The backwards
  # (root-to-parameters) pass comes first and propagates the information as far
  # up as possible; the forward (parameters-to-root) pass comes second.
  for order in (TraversalOrder.BACKWARDS, TraversalOrder.FORWARD):
    for op in module.body:
      traverse_op(op, inference_step, order)

  # Layouts have now been propagated as far as possible, but some ops may
  # still be unannotated---for example, if no op in the module was annotated
  # to begin with. Every remaining op that should have a layout gets a strided
  # fragmented layout.
  def set_default_layout(op: ir.OpView):
    layout = strided_fragmented_layout()
    if not _layout_inference_should_process_op(op):
      return
    if _has_any_layout_set(op):
      return
    in_layouts = [layout] * len(op.operands)
    out_layouts = [layout] * len(op.results)
    _set_layout_attributes(op, in_layouts, out_layouts)

  for op in module.body:
    traverse_op(op, set_default_layout)


# Signature of a lowering rule: a callable that takes an operation (or op view)
# and returns a sequence of values.
# NOTE(review): presumably the returned values replace the op's results; the
# code using this alias is not visible here, confirm against it.
MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]


Expand Down
78 changes: 76 additions & 2 deletions tests/mosaic/gpu_dialect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
from jax._src.lib.mlir.dialects import scf
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx # pylint: disable=g-importing-member,g-multiple-import
from jax.experimental.mosaic.gpu import infer_layout # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import
from jax.experimental.mosaic.gpu import splat_fragmented_layout # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import strided_fragmented_layout # pylint: disable=g-importing-member

# Shortcut to `mgpu._cext`; `None` when the `mgpu` dialect module is itself
# `None`.
_cext = mgpu._cext if mgpu is not None else None

Expand Down Expand Up @@ -79,7 +82,7 @@ def workgroup_ptr_ty() -> ir.Type:
return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")


class DialectTest(parameterized.TestCase):
class MosaicGpuTest(parameterized.TestCase):

def setUp(self):
if mgpu is None:
Expand All @@ -89,6 +92,9 @@ def setUp(self):
self.enter_context(ir.Location.unknown())
self.module = ir.Module.create()


class DialectTest(MosaicGpuTest):

def test_dialect_module_is_loaded(self):
self.assertTrue(_cext.globals._check_dialect_module_loaded("mosaic_gpu"))

Expand Down Expand Up @@ -575,7 +581,7 @@ def test_wgmma_b_shape_dim_2(self):
# TODO(b/381371456): Add tests for the other WGMMA inputs.


class DialectLoweringTest(DialectTest):
class DialectLoweringTest(MosaicGpuTest):

def test_lowering_removes_mosaic_gpu_ops(self):
with ir.InsertionPoint(self.module.body):
Expand Down Expand Up @@ -642,5 +648,73 @@ def test_initialize_barrier_op_lowering_rule(self):
self.assertEqual(count.literal_value, arrival_count)


class LayoutInferenceTest(MosaicGpuTest):
  """Tests for `infer_layout` on a small chain of `arith` ops."""

  def _build_add_chain(self, type_constructor):
    """Emits `(a + b) + b` on two splat constants; returns the last result."""
    shape = (4, 8)
    elt_type = ir.BF16Type.get()

    with ir.InsertionPoint(self.module.body):
      array_type = type_constructor.get(shape, elt_type)
      zero = ir.FloatAttr.get(elt_type, 0)
      one = ir.FloatAttr.get(elt_type, 1)
      a = arith.ConstantOp(
          array_type, ir.DenseElementsAttr.get_splat(array_type, zero)
      )
      b = arith.ConstantOp(
          array_type, ir.DenseElementsAttr.get_splat(array_type, one)
      )
      return arith.addf(arith.addf(a, b), b)

  def _assert_all_ops_have_layout(self, layout):
    """Checks that each op of the module body only carries `layout`."""
    for op in self.module.body.operations:
      self.assertIn("in_layouts", op.attributes)
      self.assertIn("out_layouts", op.attributes)

      expected_in_layouts = [layout] * len(op.operands)
      expected_out_layouts = [layout] * len(op.results)
      self.assertSequenceEqual(op.attributes["in_layouts"], expected_in_layouts)
      self.assertSequenceEqual(
          op.attributes["out_layouts"], expected_out_layouts
      )

  @parameterized.parameters(ir.RankedTensorType, ir.VectorType)
  def test_infer_layout_default(self, type_constructor):
    self._build_add_chain(type_constructor)

    # When no layout is set anywhere in the module, every op should end up
    # with a strided fragmented layout.
    infer_layout(self.module)

    self._assert_all_ops_have_layout(strided_fragmented_layout())

  @parameterized.parameters(ir.RankedTensorType, ir.VectorType)
  def test_infer_layout_for_pointwise_op(self, type_constructor):
    add = self._build_add_chain(type_constructor)

    # A layout set on the output of the last op should reach every other op.
    layout = splat_fragmented_layout()
    add.owner.attributes["out_layouts"] = ir.ArrayAttr.get([layout])
    infer_layout(self.module)

    self._assert_all_ops_have_layout(layout)


if __name__ == "__main__":
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 47bff4b

Please sign in to comment.