# Captured from the GitHub Actions "Workflow file for this run" view of PR #2:
# "Add GitHub Actions workflows for running continuous tests with Pytest".

# CI - Pytest CUDA
#
# This workflow builds jaxlib + CUDA artifacts and then runs the CUDA tests with Pytest.
#
# It consists of two jobs:
# 1. build-artifacts:
#    - This job uses a reusable workflow (`build_artifacts.yml`) to build the necessary JAX wheels
#      (jaxlib, jax-cuda-plugin, and jax-cuda-pjrt).
#    - The built wheels are then uploaded to a Google Cloud Storage (GCS) bucket for later use.
# 2. run-tests:
#    - This job downloads the JAX wheels built in the `build-artifacts` job from the GCS bucket.
#    - It installs the Python dependencies listed in `build/requirements.in`.
#    - It then executes the `run_pytest_cuda.sh` script, which performs the following actions:
#      - Installs the downloaded JAX wheels.
#      - Runs the CUDA tests with Pytest.
name: CI - Pytest CUDA
on:
  # TODO: This pull_request trigger exists for testing purposes only; remove it before submitting.
  pull_request:
    branches:
      - main
  workflow_dispatch:
    inputs:
      # Read by the "Wait For Connection" step (inputs.halt-for-connection). Quoted so the
      # choice values stay strings instead of YAML 1.1 booleans.
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: false
        default: 'no'
        options:
          - 'yes'
          - 'no'
# At most one active run per workflow and branch/ref; a newer run cancels the in-progress one.
# github.head_ref is only populated for pull_request events, so fall back to github.ref
# (e.g. for workflow_dispatch runs).
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Build each wheel for each Python version via the reusable workflow and upload it to GCS
  # so that run-tests can download it.
  build-artifacts:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        runner: ["linux-x86-n2-16"]
        artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"]
        python: ["3.10", "3.13"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts: true
      # NOTE(review): presumably appended to the GCS bucket path inside build_artifacts.yml;
      # keep in sync with the GCS_DOWNLOAD_URI used by the download step in run-tests.
      upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Download the wheels built above and run the CUDA tests for every combination of
  # CUDA container image, Python version and x64 mode.
  run-tests:
    needs: build-artifacts
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        test_env:
          - cuda_version: "12.3"
            image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
          - cuda_version: "12.1"
            image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
        runner: ["linux-x86-g2-48-l4-4gpu"]
        python: ["3.10", "3.13"]
        enable-x_64: [1, 0]
    runs-on: ${{ matrix.runner }}
    container:
      image: ${{ matrix.test_env.image }}
    name: "Pytest CUDA (${{ matrix.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
      JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}  # Whether to enable/disable x64 mode
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      # Halt for testing
      - name: Wait For Connection
        # NOTE(review): unlike actions/checkout above, this action is not pinned to a commit SHA.
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Set Platform env var for use in artifact download URL
        # Produces e.g. "linux_x86_64" and exports it as PLATFORM to later steps.
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          echo "PLATFORM=${os}_${arch}" >> "$GITHUB_ENV"
      - name: Download artifacts that were built in the "build-artifacts" job
        env:
          # Expanded via the environment instead of splicing ${{ }} into the script, because
          # the workflow name ("CI - Pytest CUDA") contains spaces.
          GCS_DOWNLOAD_URI: "gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}"
        # The whole URL, including the *.whl wildcard, is quoted so gsutil (not the shell)
        # expands the wildcard against the bucket.
        run: >-
          mkdir -p "$(pwd)/dist" &&
          gsutil -m cp -r
          "${GCS_DOWNLOAD_URI}/${PLATFORM}/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl"
          "$(pwd)/dist/"
      - name: Install dependencies
        env:
          JAXCI_PYTHON: python${{ matrix.python }}
        run: $JAXCI_PYTHON -m pip install -r build/requirements.in
      - name: Run Pytest CUDA tests
        run: ./ci/run_pytest_cuda.sh