Add GitHub Actions workflows for running continuous tests with Pytest #3
Workflow file for this run
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CI - Pytest CUDA
#
# This workflow builds the jaxlib and CUDA plugin artifacts and then runs the CUDA tests with Pytest.
#
# It consists of two jobs:
# 1. build-artifacts:
#    - This job uses a reusable workflow (`build_artifacts.yml`) to build the necessary JAX wheels
#      (jaxlib, jax-cuda-plugin, and jax-cuda-pjrt).
#    - The built wheels are then uploaded to a Google Cloud Storage (GCS) bucket for later use.
# 2. run-tests:
#    - This job downloads the JAX wheels built in the `build-artifacts` job from the GCS bucket.
#    - It installs the Python dependencies listed in `build/requirements.in`.
#    - It then executes the `run_pytest_cuda.sh` script, which performs the following actions:
#      - Installs the downloaded JAX wheels.
#      - Runs the CUDA tests with Pytest.
name: 'CI - Pytest CUDA'

# The workflow runs on pull requests that target main (temporary, see the TODO)
# and on manual dispatch from the Actions tab.
on:
  # TODO: For testing purposes, remove before submitting
  pull_request:
    branches: [main]
  workflow_dispatch:
    inputs:
      # Read by the "Wait For Connection" step (ci_connection action) in run-tests.
      # The values are quoted so YAML does not turn them into booleans.
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: false
        default: 'no'
        options: ['yes', 'no']
# Only one run per workflow and branch is kept alive. The group uses the PR head
# branch (github.head_ref) when there is one, otherwise the ref. A newer run
# cancels any run still in progress in the same group.
concurrency:
  group: '${{ github.workflow }}-${{ github.head_ref || github.ref }}'
  cancel-in-progress: true
jobs:
  # Builds the jaxlib, jax-cuda-plugin and jax-cuda-pjrt wheels with the reusable
  # build_artifacts.yml workflow and uploads them under a per-run GCS prefix.
  build-artifacts:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        runner: ["linux-x86-n2-16"]
        artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"]
        # Keep in sync with run-tests' python matrix. The download path there
        # includes a python<version>/ directory.
        python: ["3.10"]  # "3.13" currently disabled
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts: true
      # Must match the GCS path that run-tests downloads from (GCS_ARTIFACTS_URI).
      upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Downloads the wheels built above and runs the CUDA Pytest suite. There is one
  # matrix entry per CUDA container image, Python version and x64 setting.
  run-tests:
    needs: build-artifacts  # waits for every build-artifacts matrix entry
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        test_env:
          - cuda_version: "12.3"
            image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
          - cuda_version: "12.1"
            image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
        runner: ["linux-x86-g2-48-l4-4gpu"]
        # Keep in sync with build-artifacts' python matrix.
        python: ["3.10"]  # "3.13" currently disabled
        enable-x_64: [1, 0]
    runs-on: ${{ matrix.runner }}
    container:
      image: ${{ matrix.test_env.image }}
    name: "Pytest CUDA (${{ matrix.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
      JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}  # Whether to enable (1) or disable (0) x64 mode
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      # Halt for testing
      # NOTE(review): this action is pinned to a branch (@main), unlike
      # actions/checkout above. Consider pinning it to a commit SHA.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      # Exports PLATFORM as <lowercased kernel name>_<machine arch> (e.g. linux_x86_64)
      # for the following steps.
      - name: Set Platform env var for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          echo "PLATFORM=${os}_${arch}" >> "$GITHUB_ENV"
      - name: Download artifacts that were built in the "build-artifacts" job
        env:
          # Must match build-artifacts' upload_destination_prefix.
          # NOTE(review): github.run_attempt appears to break "Re-run failed jobs".
          # The attempt number increments, but the successful build-artifacts jobs
          # are not re-run, so nothing is uploaded under the new attempt and this
          # download finds no wheels. Consider passing the upload location from
          # build-artifacts as a job output instead.
          GCS_ARTIFACTS_URI: 'gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        # The URI is quoted as a whole. The workflow name contains spaces, and
        # gsutil (not the shell) should expand the *.whl wildcard.
        run: |
          mkdir -p "$(pwd)/dist"
          gsutil -m cp -r "${GCS_ARTIFACTS_URI}/${PLATFORM}/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl" "$(pwd)/dist/"
      - name: Install dependencies
        env:
          JAXCI_PYTHON: python${{ matrix.python }}
        run: '"$JAXCI_PYTHON" -m pip install -r build/requirements.in'
      - name: Run Pytest CUDA tests
        run: ./ci/run_pytest_cuda.sh