Skip to content

Commit

Permalink
[Mosaic GPU] Fix layout API bugs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715077057
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Jan 13, 2025
1 parent dabe27b commit f69592a
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,16 @@ class ParameterizedLayout:
kwargs: Any


def _get_mgpu_layout(layout: Layout | ParameterizedLayout
) -> mgpu.FragmentedLayout:
if isinstance(layout, Layout):
return layout.value()
elif isinstance(layout, ParameterizedLayout):
return layout.layout_cls.value(*layout.args,
**layout.kwargs)
else:
raise TypeError(f"Unsupported layout: {layout}")

layout_cast_p = jax_core.Primitive("layout_cast")


Expand All @@ -664,14 +674,7 @@ def _layout_cast_abstract_eval(x, new_layout):
@lowering.register_lowering_rule(layout_cast_p)
def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout):
del ctx # Unused.
if isinstance(new_layout, Layout):
return x.to_layout(new_layout.value())
elif isinstance(new_layout, ParameterizedLayout):
layout = new_layout.layout_cls(*new_layout.args,
**new_layout.kwargs)
return x.to_layout(layout)
else:
raise TypeError(f"Unsupported layout: {new_layout}")
return x.to_layout(_get_mgpu_layout(new_layout))


def layout_cast(x: Any, new_layout: Layout | ParameterizedLayout):
Expand Down Expand Up @@ -753,7 +756,7 @@ def _broadcasted_iota_lowering(
return mgpu.FragmentedArray.splat(
llvm_dialect.mlir_undef(mlir_dtype),
shape,
layout.value,
_get_mgpu_layout(layout),
is_signed=is_signed,
).foreach(
lambda _, idx: cast(idx[dimension]),
Expand Down

0 comments on commit f69592a

Please sign in to comment.