diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 85940b4ee975..74245b6d9f5b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2801,16 +2801,15 @@ def _device_id_to_logical( # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) mesh_strides = ctx.lowering_context.mesh_context.mesh_strides - def _linearize_mesh_indices(*indices): - return sum(a * b for a, b in zip(indices, mesh_strides)) - lower_ctx = LoweringRuleContext( - lowering_context=ctx.lowering_context, - avals_in=[pallas_core.index_map_grid_aval] * len(device_ids), - avals_out=[pallas_core.index_map_grid_aval], - block_shapes=(None,) * len(device_ids), + + i32 = ir.IntegerType.get_signless(32) + return functools.reduce( + arith.addi, + ( + arith.muli(a, arith.constant(i32, b)) + for a, b in zip(device_ids, mesh_strides) + ), ) - return lower_fun(_linearize_mesh_indices, multiple_results=False)( - lower_ctx, *device_ids) elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: return device_id raise NotImplementedError(f"Unsupported device id type: {device_id_type}")