diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 48e2bb5a1a6d..d501f926736c 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -481,7 +481,7 @@ jax_multiplatform_test( "gpu_a100_x32", "gpu_h100_x32", ], - shard_count = 2, + shard_count = 6, deps = [ "//jax:pallas", "//jax:pallas_gpu",