Skip to content

Commit

Permalink
[pallas] fix jumble test flakiness
Browse files Browse the repository at this point in the history
* Enable interpret mode in tests
* Ensure that the kernel is run multiple times in the test where we've seen data corruption
* Use a masked comparison - the prior comparison was reading garbage data, because it relied on how uninitialized memory happened to behave in the past.
* This was being hidden by a cache: the interpret test, which always sees 0.0 for uninitialized memory, was being hit first, whereas TPU does not have the same behavior.

PiperOrigin-RevId: 703272002
  • Loading branch information
Google-ML-Automation committed Dec 5, 2024
1 parent 651ab18 commit 84f3f99
Showing 1 changed file with 51 additions and 46 deletions.
97 changes: 51 additions & 46 deletions tests/pallas/pallas_jumble_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@
floatx = dtypes.canonicalize_dtype(jnp.float64)


def _assert_ragged_equal_with_elementwise_mask(
    row_count, col_grid_size, ragged_shape, res, ref
):
  """Asserts that `res` and `ref` agree on the valid part of each batch entry.

  Batch entry `i` is treated as valid only in its first
  `ragged_shape[i] * 128` columns. Every element outside that region is
  replaced by the same sentinel value in both arrays before they are compared,
  so whatever is stored there never influences the outcome.

  Args:
    row_count: Number of rows in every batch entry.
    col_grid_size: Number of 128-column blocks along the column axis.
    ragged_shape: Sequence with one entry per batch element, giving how many
      128-column blocks of that element are valid.
    res: Array under test, of shape
      `(len(ragged_shape), row_count, col_grid_size * 128)`.
    ref: Reference array with the same shape as `res`.

  Raises:
    AssertionError: If either array does not have the expected shape, or if
      the two arrays differ anywhere inside the valid region.
  """
  total_columns = col_grid_size * 128
  expected_shape = (len(ragged_shape), row_count, total_columns)

  # Check shapes up front: `jnp.where` broadcasts its operands against the
  # mask, so a lower-rank or size-1 input would otherwise be stretched to the
  # mask's shape and compared without any complaint.
  for name, array in (("res", res), ("ref", ref)):
    actual_shape = tuple(np.shape(array))
    if actual_shape != expected_shape:
      raise AssertionError(
          f"{name} has shape {actual_shape}, expected {expected_shape}"
      )

  # Fill the mask in place in a NumPy buffer instead of creating a fresh
  # array for every batch entry.
  mask = np.zeros(expected_shape, dtype=bool)
  for i, r in enumerate(ragged_shape):
    mask[i, :, : r * 128] = True

  # Overwrite everything outside the valid region with a common sentinel so
  # that only valid elements can cause a mismatch.
  res_valid = jnp.where(mask, res, -1)
  ref_valid = jnp.where(mask, ref, -1)

  np.testing.assert_allclose(res_valid, ref_valid)


@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
Expand Down Expand Up @@ -104,24 +119,16 @@ def invoke_kernel(x):
axis_size=3,
)(x)

res = res.data
total = len(ragged_shape) * row_count * col_grid_size * 128
res_total = np.prod(res.shape)
self.assertEqual(res_total, total)
ragged_total = 0
for dim in ragged_shape:
ragged_total += row_count * dim * 128

def correct(v):
return np.count_nonzero(v == jnp.sin(1.0))

for b, batch in enumerate(res):
ragged_val = ragged_shape[b]
for r, row in enumerate(batch):
row_total = ragged_val * 128
self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}")
ref = jax.vmap(
jnp.sin,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x)

self.assertEqual(correct(res), ragged_total)
_assert_ragged_equal_with_elementwise_mask(
row_count, col_grid_size, ragged_shape, res.data, ref.data
)

def test_vmap_jumble_over_add_kernel(self):
if not jtu.test_device_matches(["tpu"]):
Expand Down Expand Up @@ -156,36 +163,34 @@ def invoke_kernel(x, y):
(8, col_grid_size * 128), dtype=jnp.float32
),
grid=(1, col_grid_size),
interpret=False,
interpret=self.INTERPRET,
)(x, y)

res = jax.vmap(
invoke_kernel,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x, y)
# We've had this test fail with data corruption due to multiple
# invocations, so we run it several times to make sure it's not setting up
# memory incorrectly for subsequent invocations.
for _ in range(4):
res = jax.vmap(
invoke_kernel,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x, y)

res = res.data
total = len(ragged_shape) * row_count * col_grid_size * 128
res_total = np.prod(res.shape)
self.assertEqual(res_total, total)
ragged_total = 0
for dim in ragged_shape:
ragged_total += row_count * dim * 128

def correct(v):
return np.count_nonzero(v == 2.0)

for r, row in enumerate(res):
ragged_val = ragged_shape[r]
row_total = ragged_val * 128 * row_count
self.assertEqual(correct(row), row_total)
for col in row:
col_total = ragged_val * 128
self.assertEqual(correct(col), col_total)

self.assertEqual(np.count_nonzero(res == 2.0), ragged_total)
res = res.data
total = len(ragged_shape) * row_count * col_grid_size * 128
res_total = np.prod(res.shape)
self.assertEqual(res_total, total)

ref = jax.vmap(
lambda x, y: x + y,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x, y)
_assert_ragged_equal_with_elementwise_mask(
row_count, col_grid_size, ragged_shape, res, ref.data
)

def test_vmap_jumble_over_sin_kernel_grid_remapping(self):
if not jtu.test_device_matches(["tpu"]):
Expand All @@ -212,7 +217,7 @@ def invoke_kernel(x):
out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32),
grid=(1, 5),
interpret=False,
interpret=self.INTERPRET,
)(x)

with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"):
Expand Down Expand Up @@ -280,7 +285,7 @@ def matmul(
),
grid=grid,
input_output_aliases={2: 0},
interpret=False,
interpret=self.INTERPRET,
)(x, y, x_sentinel)

# TODO(mvoz): parameterize this shape?
Expand Down

0 comments on commit 84f3f99

Please sign in to comment.