diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py index 8452d1ee7264..509ef08a987f 100644 --- a/tests/pallas/pallas_jumble_test.py +++ b/tests/pallas/pallas_jumble_test.py @@ -41,6 +41,21 @@ floatx = dtypes.canonicalize_dtype(jnp.float64) +def _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res, ref +): + total_columns = col_grid_size * 128 + mask = jnp.zeros((len(ragged_shape), row_count, total_columns), dtype=bool) + + for i, r in enumerate(ragged_shape): + mask = mask.at[i, :, : r * 128].set(True) + + res_valid = jnp.where(mask, res, -1) + ref_valid = jnp.where(mask, ref, -1) + + np.testing.assert_allclose(res_valid, ref_valid) + + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -104,24 +119,16 @@ def invoke_kernel(x): axis_size=3, )(x) - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - - def correct(v): - return np.count_nonzero(v == jnp.sin(1.0)) - - for b, batch in enumerate(res): - ragged_val = ragged_shape[b] - for r, row in enumerate(batch): - row_total = ragged_val * 128 - self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}") + ref = jax.vmap( + jnp.sin, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) - self.assertEqual(correct(res), ragged_total) + _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res.data, ref.data + ) def test_vmap_jumble_over_add_kernel(self): if not jtu.test_device_matches(["tpu"]): @@ -156,36 +163,34 @@ def invoke_kernel(x, y): (8, col_grid_size * 128), dtype=jnp.float32 ), grid=(1, col_grid_size), - interpret=False, + interpret=self.INTERPRET, )(x, y) - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - 
in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) + # We've had this test fail with data corruption due to multiple + # invocations, so we run it several times to make sure it's not setting up + # memory incorrectly for subsequent invocations. + for _ in range(4): + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - - def correct(v): - return np.count_nonzero(v == 2.0) - - for r, row in enumerate(res): - ragged_val = ragged_shape[r] - row_total = ragged_val * 128 * row_count - self.assertEqual(correct(row), row_total) - for col in row: - col_total = ragged_val * 128 - self.assertEqual(correct(col), col_total) - - self.assertEqual(np.count_nonzero(res == 2.0), ragged_total) + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + + ref = jax.vmap( + lambda x, y: x + y, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) + _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res, ref.data + ) def test_vmap_jumble_over_sin_kernel_grid_remapping(self): if not jtu.test_device_matches(["tpu"]): @@ -212,7 +217,7 @@ def invoke_kernel(x): out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), grid=(1, 5), - interpret=False, + interpret=self.INTERPRET, )(x) with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): @@ -280,7 +285,7 @@ def matmul( ), grid=grid, input_output_aliases={2: 0}, - interpret=False, + interpret=self.INTERPRET, )(x, y, x_sentinel) # TODO(mvoz): parameterize this shape?