Skip to content

Commit

Permalink
Ensure that the two offsets of a dynamic_slice have the same dtype re…
Browse files Browse the repository at this point in the history
…gardless

the value of config.enable_x64.

PiperOrigin-RevId: 708031525
  • Loading branch information
bixia1 authored and Google-ML-Automation committed Dec 19, 2024
1 parent de8fa8f commit 16712b5
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,8 +1347,10 @@ def partition(precision, mesh, arg_shapes, result_shape):
def lower_fn(x, y):
axis_name = arg_shardings[1].spec[0][0]
i = jax.lax.axis_index(axis_name)
# Use offset i * 0 instead of 0 to ensure that the two offsets have the
# same dtype regardless the value of config.enable_x64.
z = jax.lax.psum(
jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name)
jax.lax.dynamic_slice(x, (i * 0, i * 8), (8, 8)) @ y, (axis_name)
)
return z, z * z

Expand Down

0 comments on commit 16712b5

Please sign in to comment.