From 47bff4bd617427147db1c070b620dd4eeb6413e4 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 3 Dec 2024 09:40:20 -0800 Subject: [PATCH] [Mosaic GPU] Add an initial skeleton for a layout inference pass. Layouts are added as annotations on MLIR ops, using the `in_layouts` and `out_layouts` attributes. At this point, layout inference is done in two passes: one "backwards" pass (root-to-parameters), and one "forward" pass (parameters-to-root). Each pass goes through all the ops in the specified order, and infers a possible layout from the layout information that is available. We expect to need two passes because partial layout annotations may be provided on intermediate nodes (e.g. `wgmma`), and a single pass from the root to the parameters is therefore insufficient to properly annotate all the operations. We do not perform any check as to whether the inferred layouts can be further lowered correctly---meaning that the produced IR can possibly fail to lower later. Layouts are only inferred for ops involving at least one operand or result of type `VectorType`/`RankedTensorType`. When layouts can't be inferred for an op that should have them, we default to annotating it with strided fragmented layouts. PiperOrigin-RevId: 702370349 --- jax/experimental/mosaic/gpu/__init__.py | 12 +- .../mosaic/gpu/dialect_lowering.py | 176 +++++++++++++++++- tests/mosaic/gpu_dialect_test.py | 78 +++++++- 3 files changed, 257 insertions(+), 9 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index c1daa33576bb..03869072f407 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -31,10 +31,18 @@ if dialect is not None: from .dialect_lowering import ( gpu_address_space_to_nvptx as gpu_address_space_to_nvptx, - lower_mgpu_dialect as lower_mgpu_dialect + infer_layout, + lower_mgpu_dialect as lower_mgpu_dialect, + splat_fragmented_layout, + strided_fragmented_layout, ) else: - gpu_address_space_to_nvptx, lower_mgpu_dialect = None, None + gpu_address_space_to_nvptx = None + infer_layout = None + lower_mgpu_dialect = None + splat_fragmented_layout = None + strided_fragmented_layout = None + from .fragmented_array import ( FragmentedArray as FragmentedArray, diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 9bda5b5b7191..d9c5b54cbff6 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -15,22 +15,188 @@ """Lowering rules and pass for the MLIR Mosaic GPU dialect.""" from collections.abc import Callable +import enum import functools +import itertools import operator -from typing import Sequence, Type +from typing import List, Sequence, Tuple, Type, cast from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import gpu +from jax._src.lib.mlir.dialects import llvm +from jax._src.lib.mlir.dialects import nvvm -from jaxlib.mlir import ir -from jaxlib.mlir.dialects import gpu -from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import nvvm from .utils import c, ptr_as_memref, single_thread_predicate # mypy: ignore-errors +def strided_fragmented_layout(): + layout = mgpu.FragmentedLayout.WGStridedFragLayout + return ir.Attribute.parse(f"#mosaic_gpu.fragmented_layout<{layout}>") + + +def splat_fragmented_layout(): + layout = mgpu.FragmentedLayout.WGSplatFragLayout + return ir.Attribute.parse(f"#mosaic_gpu.fragmented_layout<{layout}>") + + +_layout_inference_rules: dict[ + str, + Callable[[ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None], +] = {} + + +def _add_layout_inference_rule( + op: Type[ir.OpView], + rule: Callable[ + [ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None + ], +): + _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + + +def _set_layout_attributes( + op: ir.OpView, + in_layouts: List[ir.Attribute], + out_layouts: List[ir.Attribute], +): + op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts) + op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts) + + +def _extract_any_layout_from_op(op: ir.OpView) -> ir.Attribute | None: + if "in_layouts" in op.attributes and len(op.operands) > 0: + return cast(ir.ArrayAttr, op.attributes["in_layouts"])[0] + elif "out_layouts" in op.attributes and len(op.results) > 0: + return cast(ir.ArrayAttr, op.attributes["out_layouts"])[0] + + return None + + +def _infer_pointwise_op_layouts( + op: ir.OpView, +) -> Tuple[List[ir.Attribute], List[ir.Attribute]] | None: + layout = _extract_any_layout_from_op(op) + # The op had no layout set. Since we're annotating ops, we may need to + # derive layout information from user or producer ops. + if layout is None: + # First, we iterate on users. + for op_result in op.results: + for op_user in cast(ir.OpResult, op_result).uses: + layout = _extract_any_layout_from_op(op_user.owner) + if layout: + break + else: + continue + break + + if layout is None: + # Still no layout set. We iterate on producers. + for operand in op.operands: + layout = _extract_any_layout_from_op(operand.owner) + if layout: + break + + if layout is None: + return None + + return ([layout for _ in op.operands], [layout for _ in op.results]) + + +for op in ( + arith.AddFOp, + arith.ConstantOp, + arith.MulFOp, +): + _add_layout_inference_rule(op, _infer_pointwise_op_layouts) + + +def _layout_inference_should_process_op(op: ir.OpView) -> bool: + """Returns 'true' if the layout inference pass can skip the operation.""" + + def is_array(v: ir.Value): + ty = v.type + return ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty) + + return any(map(is_array, itertools.chain(op.operands, op.results))) + + +def _has_any_layout_set(op: ir.OpView) -> bool: + return "in_layouts" in op.attributes or "out_layouts" in op.attributes + + +class TraversalOrder(enum.Enum): + """Traversal orders with respect to the data flow for IR.""" + + FORWARD = 1 + BACKWARDS = 2 + + +def traverse_op( + op: ir.OpView, + callback: Callable[[ir.OpView], None], + traversal_order: TraversalOrder = TraversalOrder.FORWARD, +): + """Traverses the operation and applies the callback in the given order.""" + for region in op.operation.regions: + for block in region: + if traversal_order == TraversalOrder.FORWARD: + ops_to_traverse = block + else: + ops_to_traverse = reversed(list(block)) + for block_op in ops_to_traverse: + callback(block_op) + callback(op) + + +def infer_layout(module: ir.Module): + def inference_step(op: ir.Operation): + if not _layout_inference_should_process_op(op): + return + elif inference_rule := _layout_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error + pass + else: + raise NotImplementedError(f"Can not infer layout for {op}") + + maybe_layouts = inference_rule(op) + if maybe_layouts is None: + return + + _set_layout_attributes(op, *maybe_layouts) + + # We run two passes over the module, in order to make sure that layouts + # defined in the middle of the computation are propagated wherever they need + # to be propagated. We start with a backwards (root-to-parameters) pass to + # propagate the information as far up as possible, and then a forward pass + # (parameters-to-root). + # + # Backwards pass + for op in module.body: + traverse_op(op, inference_step, TraversalOrder.BACKWARDS) + + # Forward pass + for op in module.body: + traverse_op(op, inference_step, TraversalOrder.FORWARD) + + # At this point, layouts have been propagated as far as they could be + # propagated. However, it is possible for some operations to remain + # unannotated---for example, if there were no annotations on any operation in + # the module at the start of this function. We annotate all the remaining ops + # that should be annotated with a strided fragmented layout. + def set_default_layout(op: ir.OpView): + layout = strided_fragmented_layout() + if _layout_inference_should_process_op(op) and not _has_any_layout_set(op): + _set_layout_attributes( + op, [layout] * len(op.operands), [layout] * len(op.results)) + + for op in module.body: + traverse_op(op, set_default_layout) + + MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]] diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 8305cd292baf..319af4648002 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -30,7 +30,10 @@ from jax._src.lib.mlir.dialects import scf from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx # pylint: disable=g-importing-member,g-multiple-import +from jax.experimental.mosaic.gpu import infer_layout # pylint: disable=g-importing-member from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import +from jax.experimental.mosaic.gpu import splat_fragmented_layout # pylint: disable=g-importing-member +from jax.experimental.mosaic.gpu import strided_fragmented_layout # pylint: disable=g-importing-member _cext = mgpu._cext if mgpu is not None else None @@ -79,7 +82,7 @@ def workgroup_ptr_ty() -> ir.Type: return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") -class DialectTest(parameterized.TestCase): +class MosaicGpuTest(parameterized.TestCase): def setUp(self): if mgpu is None: @@ -89,6 +92,9 @@ def setUp(self): self.enter_context(ir.Location.unknown()) self.module = ir.Module.create() + +class DialectTest(MosaicGpuTest): + def test_dialect_module_is_loaded(self): self.assertTrue(_cext.globals._check_dialect_module_loaded("mosaic_gpu")) @@ -575,7 +581,7 @@ def test_wgmma_b_shape_dim_2(self): # TODO(b/381371456): Add tests for the other WGMMA inputs. -class DialectLoweringTest(DialectTest): +class DialectLoweringTest(MosaicGpuTest): def test_lowering_removes_mosaic_gpu_ops(self): with ir.InsertionPoint(self.module.body): @@ -642,5 +648,73 @@ def test_initialize_barrier_op_lowering_rule(self): self.assertEqual(count.literal_value, arrival_count) +class LayoutInferenceTest(MosaicGpuTest): + + @parameterized.parameters(ir.RankedTensorType, ir.VectorType) + def test_infer_layout_default(self, type_constructor): + shape = (4, 8) + elt_type = ir.BF16Type.get() + + with ir.InsertionPoint(self.module.body): + ab_type = type_constructor.get(shape, elt_type) + const_zero = ir.FloatAttr.get(elt_type, 0) + const_one = ir.FloatAttr.get(elt_type, 1) + a = arith.ConstantOp( + ab_type, ir.DenseElementsAttr.get_splat(ab_type, const_zero) + ) + b = arith.ConstantOp( + ab_type, ir.DenseElementsAttr.get_splat(ab_type, const_one) + ) + arith.addf(arith.addf(a, b), b) + + # Not setting any layouts on the module should default in ops having a + # strided fragmented layout. + infer_layout(self.module) + + layout = strided_fragmented_layout() + for op in self.module.body.operations: + self.assertIn("in_layouts", op.attributes) + self.assertIn("out_layouts", op.attributes) + + self.assertSequenceEqual( + op.attributes["in_layouts"], [layout] * len(op.operands) + ) + self.assertSequenceEqual( + op.attributes["out_layouts"], [layout] * len(op.results) + ) + + @parameterized.parameters(ir.RankedTensorType, ir.VectorType) + def test_infer_layout_for_pointwise_op(self, type_constructor): + shape = (4, 8) + elt_type = ir.BF16Type.get() + + with ir.InsertionPoint(self.module.body): + ab_type = type_constructor.get(shape, elt_type) + const_zero = ir.FloatAttr.get(elt_type, 0) + const_one = ir.FloatAttr.get(elt_type, 1) + a = arith.ConstantOp( + ab_type, ir.DenseElementsAttr.get_splat(ab_type, const_zero) + ) + b = arith.ConstantOp( + ab_type, ir.DenseElementsAttr.get_splat(ab_type, const_one) + ) + add = arith.addf(arith.addf(a, b), b) + + layout = splat_fragmented_layout() + add.owner.attributes["out_layouts"] = ir.ArrayAttr.get([layout]) + infer_layout(self.module) + + for op in self.module.body.operations: + self.assertIn("in_layouts", op.attributes) + self.assertIn("out_layouts", op.attributes) + + self.assertSequenceEqual( + op.attributes["in_layouts"], [layout] * len(op.operands) + ) + self.assertSequenceEqual( + op.attributes["out_layouts"], [layout] * len(op.results) + ) + + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader())