Skip to content

Commit

Permalink
Populate --test_tag_filters through separate array variable
Browse files Browse the repository at this point in the history
  • Loading branch information
andportnoy committed Oct 11, 2024
1 parent cdf89c9 commit a2666ca
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ for t in $*; do
BAZEL_TARGET="${BAZEL_TARGET} $t"
done

TEST_TAG_FILTER_ARRAY=()
TEST_TAG_FILTER_ARRAY+=('-multiaccelerator')

COMMON_FLAGS=$(cat << EOF
--@local_config_cuda//:enable_cuda
--cache_test_results=${CACHE_TEST_RESULTS}
--test_timeout=600
--test_tag_filters=-multiaccelerator
--test_env=JAX_SKIP_SLOW_TESTS=1
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
Expand All @@ -141,7 +143,8 @@ case "${BATTERY}" in
# collect from all tests subdirectories recursively,
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
BAZEL_TARGET="${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu"
TEST_TAG_FILTER_ARRAY+=('jax_test_gpu')
BAZEL_TARGET="${BAZEL_TARGET} //tests/...
;;
backend-independent)
JOBS_PER_GPU=4
Expand All @@ -160,6 +163,8 @@ case "${BATTERY}" in
;;
esac
TEST_TAG_FILTERS=$(IFS=, echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}")
print_var NCPUS
print_var NGPUS
print_var BATTERY
Expand All @@ -168,6 +173,7 @@ print_var JOBS_PER_GPU
print_var JOBS
print_var BUILD_JAXLIB
print_var BAZEL_TARGET
print_var TEST_TAG_FILTERS
print_var COMMON_FLAGS
print_var EXTRA_FLAGS
Expand All @@ -185,4 +191,4 @@ pip install matplotlib
cd `jax_source_dir`
python build/build.py --configure_only
bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}

0 comments on commit a2666ca

Please sign in to comment.