-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Github action workflows for running continuous tests with Pytest
PiperOrigin-RevId: 702497163
- Loading branch information
1 parent
d4031e9
commit d2f5209
Showing
8 changed files
with
369 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# CI - Build JAX Artifacts | ||
# This workflow builds JAX wheels (jax, jaxlib, jax-cuda-plugin, and jax-cuda-pjrt) and optionally | ||
# uploads them to a Google Cloud Storage (GCS) bucket. It is also resusable through a workflow call | ||
# and is used by other CI workflows such as the Pytest workflows for building the artifacts. | ||
name: CI - Build JAX Artifacts | ||
|
||
on: | ||
workflow_dispatch: | ||
inputs: | ||
runner: | ||
description: "Which runner should the workflow run on?" | ||
type: choice | ||
required: true | ||
default: "linux-x86-n2-16" | ||
options: | ||
- "linux-x86-n2-16" | ||
- "linux-arm64-c4a-64" | ||
- "windows-x86-n2-16" | ||
artifact: | ||
description: "Which JAX artifact to build?" | ||
type: choice | ||
required: true | ||
default: "jaxlib" | ||
options: | ||
- "jax" | ||
- "jaxlib" | ||
- "jax-cuda-plugin" | ||
- "jax-cuda-pjrt" | ||
python: | ||
description: "Which python version should the artifact be built for?" | ||
type: choice | ||
required: false | ||
default: "3.12" | ||
options: | ||
- "3.10" | ||
- "3.11" | ||
- "3.12" | ||
- "3.13" | ||
clone_main_xla: | ||
description: "Should latest XLA be used?" | ||
type: choice | ||
required: false | ||
default: "0" | ||
options: | ||
- "1" | ||
- "0" | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: false | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
- 'no' | ||
workflow_call: | ||
inputs: | ||
runner: | ||
description: "Which runner should the workflow run on?" | ||
type: string | ||
required: true | ||
default: "linux-x86-n2-16" | ||
artifact: | ||
description: "Which JAX artifact to build?" | ||
type: string | ||
required: true | ||
default: "jaxlib" | ||
python: | ||
description: "Which python version should the artifact be built for?" | ||
type: string | ||
required: false | ||
default: "3.12" | ||
clone_main_xla: | ||
description: "Should latest XLA be used?" | ||
type: string | ||
required: false | ||
default: "0" | ||
upload_artifacts: | ||
description: "Should the artifacts be uploaded to a GCS bucket?" | ||
required: false | ||
default: false | ||
type: boolean | ||
upload_destination_prefix: | ||
description: "GCS location prefix to where the artifacts should be uploaded" | ||
required: false | ||
default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' | ||
type: string | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build-artifacts: | ||
defaults: | ||
run: | ||
# Explicitly set the shell to bash to override Windows's default (cmd) | ||
shell: bash | ||
|
||
runs-on: ${{ inputs.runner }} | ||
|
||
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | ||
(contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || | ||
(contains(inputs.runner, 'windows-x86') && null) }} | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" | ||
JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" | ||
|
||
name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) | ||
|
||
steps: | ||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||
|
||
- name: Enable RBE if building on Linux x86 or Windows x86 | ||
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') | ||
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV | ||
|
||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
|
||
- name: Build ${{ inputs.artifact }} | ||
run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" | ||
|
||
- name: Set PLATFORM env var for use in artifact upload URL | ||
run: | | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
# Adjust name for Windows | ||
if [[ $os =~ "msys_nt" ]]; then | ||
os="windows" | ||
fi | ||
echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV | ||
- name: Upload artifacts to a GCS bucket (non-Windows runs) | ||
if: >- | ||
${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }} | ||
run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ | ||
|
||
# Set shell to cmd to avoid path errors when using gcloud commands on Windows | ||
- name: Upload artifacts to a GCS bucket (Windows runs) | ||
if: >- | ||
${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }} | ||
shell: cmd | ||
run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# CI - Pytest CPU | ||
# | ||
# This workflow builds jaxlib artifact and then runs the CPU tests with Pytest. | ||
# | ||
# It consists of two jobs: | ||
# 1. build-artifacts: | ||
# - This job uses a reusable workflow (`build_artifacts.yml`) to build the necessary jaxlib wheel | ||
# - The built wheel is then uploaded to a Google Cloud Storage (GCS) bucket for later use. | ||
# 2. run-tests: | ||
# - This job downloads the jaxlib wheel built in the `build-artifacts` job from the GCS bucket. | ||
# - It then executes the `run_pytest_cpu.sh` script, which performs the following actions: | ||
# - Installs the downloaded jaxlib wheel. | ||
# - Runs the CPU tests with Pytest. | ||
name: CI - Pytest CPU | ||
|
||
on: | ||
# TODO: For testing purposes, remove before submitting | ||
pull_request: | ||
branches: | ||
- main | ||
workflow_dispatch: | ||
inputs: | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: false | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build-artifacts: | ||
uses: ./.github/workflows/build_artifacts.yml | ||
strategy: | ||
fail-fast: false # don't cancel all jobs on failure | ||
matrix: | ||
runner: ["linux-x86-n2-16", "linux-arm64-c4a-64", "windows-x86-n2-16"] | ||
artifact: ["jaxlib"] | ||
python: ["3.10", "3.13"] | ||
with: | ||
runner: ${{ matrix.runner }} | ||
artifact: ${{ matrix.artifact }} | ||
python: ${{ matrix.python }} | ||
clone_main_xla: 1 | ||
upload_artifacts: true | ||
upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' | ||
|
||
run-tests: | ||
needs: build-artifacts | ||
defaults: | ||
run: | ||
# Explicitly set the shell to bash to override Windows's default (cmd) | ||
shell: bash | ||
|
||
strategy: | ||
fail-fast: false # don't cancel all jobs on failure | ||
matrix: | ||
runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] | ||
python: ["3.10", "3.13"] | ||
enable-x_64: [1, 0] | ||
|
||
runs-on: ${{ matrix.runner }} | ||
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || | ||
(contains(matrix.runner, 'windows-x86') && null) }} | ||
|
||
name: "Pytest CPU (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} | ||
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} # Whether to enable/disabe x64 mode | ||
|
||
steps: | ||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Set Platform env var for use in artifact download URL | ||
run: | | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
# Adjust name for Windows | ||
if [[ $os =~ "msys_nt" ]]; then | ||
os="windows" | ||
fi | ||
echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV | ||
- name: Download artifacts that were built in the "build-artifacts" job (non-Windows runs) | ||
if: ${{ !contains(matrix.runner, 'windows-x86') }} | ||
run: >- | ||
mkdir -p $(pwd)/dist && | ||
gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ | ||
- name: Download artifacts that were built in the "build-artifacts" job (Windows runs) | ||
if: ${{ contains(matrix.runner, 'windows-x86') }} | ||
shell: cmd | ||
run: >- | ||
mkdir dist && | ||
gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl dist/ | ||
- name: Install dependencies | ||
env: | ||
JAXCI_PYTHON: python${{ matrix.python }} | ||
run: $JAXCI_PYTHON -m pip install -r build/requirements.in | ||
- name: Run Pytest CPU tests | ||
run: ./ci/run_pytest_cpu.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# CI - Pytest CUDA | ||
# | ||
# This workflow builds jaxlib + CUDA artifacts and then runs the CUDA tests with Pytest. | ||
# | ||
# It consists of two jobs: | ||
# 1. build-artifacts: | ||
# - This job uses a reusable workflow (`build_artifacts.yml`) to build the necessary JAX wheels | ||
# (jaxlib, jax-cuda-plugin, and jax-cuda-pjrt). | ||
# - The built wheels are then uploaded to a Google Cloud Storage (GCS) bucket for later use. | ||
# 2. run-tests: | ||
# - This job downloads the JAX wheels built in the `build-artifacts` job from the GCS bucket. | ||
# - It then executes the `run_pytest_cuda.sh` script, which performs the following actions: | ||
# - Installs the downloaded JAX wheels. | ||
# - Runs the CUDA tests with Pytest. | ||
name: CI - Pytest CUDA | ||
|
||
on: | ||
# TODO: For testing purposes, remove before submitting | ||
pull_request: | ||
branches: | ||
- main | ||
workflow_dispatch: | ||
inputs: | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: false | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
- 'no' | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build-artifacts: | ||
uses: ./.github/workflows/build_artifacts.yml | ||
strategy: | ||
fail-fast: false # don't cancel all jobs on failure | ||
matrix: | ||
runner: ["linux-x86-n2-16"] | ||
artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"] | ||
python: ["3.10", "3.13"] | ||
with: | ||
runner: ${{ matrix.runner }} | ||
artifact: ${{ matrix.artifact }} | ||
python: ${{ matrix.python }} | ||
clone_main_xla: 1 | ||
upload_artifacts: true | ||
upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' | ||
|
||
run-tests: | ||
needs: build-artifacts | ||
|
||
strategy: | ||
fail-fast: false # don't cancel all jobs on failure | ||
matrix: | ||
test_env: [ | ||
{cuda_version: "12.3", | ||
image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, | ||
{cuda_version: "12.1", | ||
image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, | ||
] | ||
runner: ["linux-x86-g2-48-l4-4gpu"] | ||
python: ["3.10", "3.13"] | ||
enable-x_64: [1, 0] | ||
|
||
runs-on: ${{ matrix.runner }} | ||
container: | ||
image: ${{ matrix.test_env.image }} | ||
|
||
name: "Pytest CUDA (${{ matrix.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} | ||
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} # Whether to enable/disabe x64 mode | ||
|
||
steps: | ||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Set Platform env var for use in artifact download URL | ||
run: | | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV | ||
- name: Download artifacts that were built in the "build-artifacts" job | ||
run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ | ||
- name: Install dependencies | ||
env: | ||
JAXCI_PYTHON: python${{ matrix.python }} | ||
run: $JAXCI_PYTHON -m pip install -r build/requirements.in | ||
- name: Run Pytest CUDA tests | ||
run: ./ci/run_pytest_cuda.sh |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.