Skip to content

Commit

Permalink
Collect recursively and filter GPU tests using jax_test_gpu tag
Browse files Browse the repository at this point in the history
  • Loading branch information
andportnoy committed Oct 11, 2024
1 parent e8043a5 commit cdf89c9
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ case "${BATTERY}" in
JOBS_PER_GPU=8
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
# collect from all tests subdirectories recursively,
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
BAZEL_TARGET="${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu"
;;
backend-independent)
JOBS_PER_GPU=4
Expand Down

0 comments on commit cdf89c9

Please sign in to comment.