diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 6afbaace1..880c3cabd 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -138,7 +138,10 @@ case "${BATTERY}" in JOBS_PER_GPU=8 JOBS=$((NGPUS * JOBS_PER_GPU)) EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow" - BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests" + # collect from all tests subdirectories recursively, + # use jax_test_gpu tag generated by jax_multiplatform_test rule: + # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265 + BAZEL_TARGET="${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu" ;; backend-independent) JOBS_PER_GPU=4