diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml new file mode 100644 index 000000000000..718725995bfb --- /dev/null +++ b/.github/workflows/build_artifacts.yml @@ -0,0 +1,148 @@ +# CI - Build JAX Artifacts +# This workflow builds JAX wheels (jax, jaxlib, jax-cuda-plugin, and jax-cuda-pjrt) and optionally +# uploads them to a Google Cloud Storage (GCS) bucket. It is also resusable through a workflow call +# and is used by other CI workflows such as the Pytest workflows for building the artifacts. +name: CI - Build JAX Artifacts + +on: + workflow_dispatch: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: choice + required: true + default: "linux-x86-n2-16" + options: + - "linux-x86-n2-16" + - "linux-arm64-c4a-64" + - "windows-x86-n2-16" + artifact: + description: "Which JAX artifact to build?" + type: choice + required: true + default: "jaxlib" + options: + - "jax" + - "jaxlib" + - "jax-cuda-plugin" + - "jax-cuda-pjrt" + python: + description: "Which python version should the artifact be built for?" + type: choice + required: false + default: "3.12" + options: + - "3.10" + - "3.11" + - "3.12" + - "3.13" + clone_main_xla: + description: "Should latest XLA be used?" + type: choice + required: false + default: "0" + options: + - "1" + - "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: false + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + artifact: + description: "Which JAX artifact to build?" + type: string + required: true + default: "jaxlib" + python: + description: "Which python version should the artifact be built for?" + type: string + required: false + default: "3.12" + clone_main_xla: + description: "Should latest XLA be used?" + type: string + required: false + default: "0" + upload_artifacts: + description: "Should the artifacts be uploaded to a GCS bucket?" + required: false + default: false + type: boolean + upload_destination_prefix: + description: "GCS location prefix to where the artifacts should be uploaded" + required: false + default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + +permissions: + contents: read + +jobs: + build-artifacts: + defaults: + run: + # Explicitly set the shell to bash to override Windows's default (cmd) + shell: bash + + runs-on: ${{ inputs.runner }} + + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'windows-x86') && null) }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" + + name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') + run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV + + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + + - name: Build ${{ inputs.artifact }} + run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" + + - name: Set PLATFORM env var for use in artifact upload URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + fi + + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + + - name: Upload artifacts to a GCS bucket (non-Windows runs) + if: >- + ${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }} + run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ + + # Set shell to cmd to avoid path errors when using gcloud commands on Windows + - name: Upload artifacts to a GCS bucket (Windows runs) + if: >- + ${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }} + shell: cmd + run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml new file mode 100644 index 000000000000..d6e240366a02 --- /dev/null +++ b/.github/workflows/pytest_cpu.yml @@ -0,0 +1,111 @@ +# CI - Pytest CPU +# +# This workflow builds jaxlib artifact and then runs the CPU tests with Pytest. +# +# It consists of two jobs: +# 1. build-artifacts: +# - This job uses a reusable workflow (`build_artifacts.yml`) to build the necessary jaxlib wheel +# - The built wheel is then uploaded to a Google Cloud Storage (GCS) bucket for later use. +# 2. run-tests: +# - This job downloads the jaxlib wheel built in the `build-artifacts` job from the GCS bucket. +# - It then executes the `run_pytest_cpu.sh` script, which performs the following actions: +# - Installs the downloaded jaxlib wheel. +# - Runs the CPU tests with Pytest. +name: CI - Pytest CPU + +on: + # TODO: For testing purposes, remove before submitting + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: false + default: 'no' + options: + - 'yes' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + build-artifacts: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64", "windows-x86-n2-16"] + artifact: ["jaxlib"] + python: ["3.10", "3.13"] + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts: true + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run-tests: + needs: build-artifacts + defaults: + run: + # Explicitly set the shell to bash to override Windows's default (cmd) + shell: bash + + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] + python: ["3.10", "3.13"] + enable-x_64: [1, 0] + + runs-on: ${{ matrix.runner }} + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(matrix.runner, 'windows-x86') && null) }} + + name: "Pytest CPU (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} # Whether to enable/disabe x64 mode + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set Platform env var for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + fi + + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download artifacts that were built in the "build-artifacts" job (non-Windows runs) + if: ${{ !contains(matrix.runner, 'windows-x86') }} + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + - name: Download artifacts that were built in the "build-artifacts" job (Windows runs) + if: ${{ contains(matrix.runner, 'windows-x86') }} + shell: cmd + run: >- + mkdir dist && + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl dist/ + - name: Install dependencies + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest CPU tests + run: ./ci/run_pytest_cpu.sh diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml new file mode 100644 index 000000000000..4a17b22b10b5 --- /dev/null +++ b/.github/workflows/pytest_cuda.yml @@ -0,0 +1,99 @@ +# CI - Pytest CUDA +# +# This workflow builds jaxlib + CUDA artifacts and then runs the CUDA tests with Pytest. +# +# It consists of two jobs: +# 1. build-artifacts: +# - This job uses a reusable workflow (`build_artifacts.yml`) to build the necessary JAX wheels +# (jaxlib, jax-cuda-plugin, and jax-cuda-pjrt). +# - The built wheels are then uploaded to a Google Cloud Storage (GCS) bucket for later use. +# 2. run-tests: +# - This job downloads the JAX wheels built in the `build-artifacts` job from the GCS bucket. +# - It then executes the `run_pytest_cuda.sh` script, which performs the following actions: +# - Installs the downloaded JAX wheels. +# - Runs the CUDA tests with Pytest. +name: CI - Pytest CUDA + +on: + # TODO: For testing purposes, remove before submitting + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: false + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + build-artifacts: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-16"] + artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"] + python: ["3.10", "3.13"] + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts: true + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run-tests: + needs: build-artifacts + + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + test_env: [ + {cuda_version: "12.3", + image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + {cuda_version: "12.1", + image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + ] + runner: ["linux-x86-g2-48-l4-4gpu"] + python: ["3.10", "3.13"] + enable-x_64: [1, 0] + + runs-on: ${{ matrix.runner }} + container: + image: ${{ matrix.test_env.image }} + + name: "Pytest CUDA (${{ matrix.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} # Whether to enable/disabe x64 mode + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set Platform env var for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download artifacts that were built in the "build-artifacts" job + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + - name: Install dependencies + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest CUDA tests + run: ./ci/run_pytest_cuda.sh \ No newline at end of file diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh old mode 100644 new mode 100755 diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 2b19ca5ddaa5..0b045bdc7927 100644 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -39,6 +39,7 @@ source "ci/utilities/setup_build_environment.sh" export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_64="$JAXCI_ENABLE_X64" # End of test environment variable setup echo "Running CPU tests..." diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_cuda.sh similarity index 94% rename from ci/run_pytest_gpu.sh rename to ci/run_pytest_cuda.sh index 7bc2492781b2..e6dc3c18dead 100644 --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_cuda.sh @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt +# Runs Pyest CUDA tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt # wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) # # -e: abort script if one command fails @@ -43,6 +43,7 @@ export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export NCCL_DEBUG=WARN export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_64="$JAXCI_ENABLE_X64" # Set the number of processes to run to be 4x the number of GPUs. export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) @@ -52,7 +53,7 @@ export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 # End of test environment variable setup -echo "Running GPU tests..." +echo "Running CUDA tests..." "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 181256b90804..4af679c9d079 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -26,7 +26,13 @@ fi echo "Installing the following wheels:" echo "${WHEELS[@]}" -"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" + +# On Windows, convert MSYS Linux-like paths to Windows paths. +if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m pip install $(cygpath -w "${WHEELS[@]}") +else + "$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" +fi echo "Installing the JAX package in editable mode at the current commit..." # Install JAX package at the current commit. diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh old mode 100644 new mode 100755