# Nightly JAX CI on NVIDIA GPUs #643
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
name: Nightly JAX CI on NVIDIA GPUs
# This configures running JAX tests (especially multi-node multi-gpu) against nightly GPU jaxlib builds.
# This is expected to fail frequently, and so we don't run it against every commit and PR in the repository.
# Portions of this adapted from https://github.com/google/jax/blob/main/.github/workflows/upstream-nightly.yaml

# Controls when the workflow will run
on:
  schedule:
    - cron: "0 12 * * *"  # Daily at 12:00 UTC
  workflow_dispatch:  # allows triggering the workflow run manually
  pull_request:  # Automatically trigger on pull requests affecting this file
    branches:
      - main
    paths:
      - '**workflows/nightly-ci-multiprocess-gpu.yml'

jobs:
  # Submits a 2-node slurm batch job (multinode_pytest_jaxlib_nightly.sub), streams its
  # output into this step's log, and fails the step if the slurm job did not succeed.
  jaxlib-nightly:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v3

      - name: Launch slurm job and hook output to this shell
        run: |
          export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
          # NOTE(review): job_wait, job_state, job_time and job_exit_code are not defined in
          # this workflow; presumably they come from this sourced helper -- confirm.
          source $JOBSCRIPTSDIR/slurm_utils_common.sh
          sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest_jaxlib_nightly.sub | tee output.log
          sleep 2m
          export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
          # Without `pipefail` the status of `sbatch | tee` is tee's, so a failed submission
          # only shows up as a missing job id. Stop here instead of querying an empty id.
          if [ -z "${SLURM_JOBID}" ]; then echo "sbatch did not report a job id" >&2; exit 1; fi
          export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
          job_wait "${SLURM_JOBID}" & PID=$!
          # Make sure the output file exists so that `tail -f` can attach to it.
          touch "${SLURM_OUTPUT}"
          echo -e " ---------------------------------------------------\n" \
                  "----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
                  "---------------------------------------------------\n"
          # Follow the job's stdout file until the background job_wait process exits.
          tail --pid="${PID}" -f "${SLURM_OUTPUT}"
          export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
          export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
          export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
          # Propagate the slurm job's result as the result of this step.
          if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
          if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi

      # `if: always()` keeps the reporting steps running even when the slurm step failed.
      - name: Publish Test Results
        uses: EnricoMi/publish-unit-test-result-action@v2
        if: always()
        with:
          junit_files: "outputs/*.xml"

      - name: Upload run results from all nodes
        uses: actions/upload-artifact@v3
        if: always()
        with:
          # The `report` job reads its logs from a directory with this artifact name.
          name: output-from-nodes
          path: "outputs/*.txt"

  # Same flow as jaxlib-nightly, but submits multinode_pytest_jaxlib_release.sub.
  jaxlib-release:
    runs-on: self-hosted
    # Only starts once jaxlib-nightly has completed successfully.
    needs: jaxlib-nightly
    steps:
      - uses: actions/checkout@v3

      - name: Launch slurm job and hook output to this shell
        run: |
          export JOBSCRIPTSDIR=${GITHUB_WORKSPACE}/.github/workflows/slurm_job_scripts
          # NOTE(review): job_wait, job_state, job_time and job_exit_code are not defined in
          # this workflow; presumably they come from this sourced helper -- confirm.
          source $JOBSCRIPTSDIR/slurm_utils_common.sh
          sbatch -N 2 $JOBSCRIPTSDIR/multinode_pytest_jaxlib_release.sub | tee output.log
          sleep 2m
          export SLURM_JOBID=$(grep 'Submitted batch job' "output.log" | awk '{ print $4 }')
          # Without `pipefail` the status of `sbatch | tee` is tee's, so a failed submission
          # only shows up as a missing job id. Stop here instead of querying an empty id.
          if [ -z "${SLURM_JOBID}" ]; then echo "sbatch did not report a job id" >&2; exit 1; fi
          export SLURM_OUTPUT=$(scontrol show job "${SLURM_JOBID}" | grep 'StdOut' | awk -F '=' '{ print $2 }')
          job_wait "${SLURM_JOBID}" & PID=$!
          # Make sure the output file exists so that `tail -f` can attach to it.
          touch "${SLURM_OUTPUT}"
          echo -e " ---------------------------------------------------\n" \
                  "----------WAITING FOR SLURM JOB TO BEGIN-----------\n" \
                  "---------------------------------------------------\n"
          # Follow the job's stdout file until the background job_wait process exits.
          tail --pid="${PID}" -f "${SLURM_OUTPUT}"
          export SLURM_STATE=$(job_state "${SLURM_JOBID}"); echo "SLURM_JOBID=${SLURM_JOBID} SLURM_STATE='${SLURM_STATE}'"
          export SLURM_WALLTIME=$(job_time "${SLURM_JOBID}"); echo "SLURM_WALLTIME=${SLURM_WALLTIME} secs"
          export SLURM_EXITCODE=$(job_exit_code "${SLURM_JOBID}" || echo $?); echo "SLURM_EXITCODE='${SLURM_EXITCODE}'"
          # Propagate the slurm job's result as the result of this step.
          if [ "${SLURM_EXITCODE}" != "0" ]; then exit ${SLURM_EXITCODE:-999}; fi
          if [ "${SLURM_STATE}" != "COMPLETED" ]; then exit 1; fi

      # `if: always()` keeps the reporting steps running even when the slurm step failed.
      - name: Publish Test Results
        uses: EnricoMi/publish-unit-test-result-action@v2
        if: always()
        with:
          junit_files: "outputs/*.xml"

      - name: Upload run results from all nodes
        uses: actions/upload-artifact@v3
        if: always()
        with:
          # NOTE(review): same artifact name as in jaxlib-nightly; the `report` job reads
          # its logs from a directory with this name.
          name: output-from-nodes
          path: "outputs/*.txt"

  # Opens (or refreshes) a tracking issue when a scheduled run of the jobs above failed.
  report:
    name: report
    needs: [jaxlib-nightly, jaxlib-release]
    # Only report for scheduled runs; manual and pull-request runs never touch the issue.
    if: |
      failure()
      && github.event_name == 'schedule'
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: "3.x"
      - uses: actions/download-artifact@v3
        with:
          path: /tmp/workspace/logs
      - name: Parse log output
        run: |
          ls /tmp/workspace/logs/output-from-nodes/
          python .github/workflows/cat_slurm_logs.py /tmp/workspace/logs/output-from-nodes/*.txt --outfile=parsed-logs.txt
      - name: Report failures
        uses: actions/github-script@v6
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const fs = require('fs');

            // parsed-logs.txt is written by the preceding "Parse log output" step.
            const parsedLogs = fs.readFileSync('parsed-logs.txt', 'utf8');
            const title = "⚠️ Nightly GPU Multiprocess CI failed ⚠️";
            const workflowUrl = `https://github.com/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`;
            const issueBody = `[Workflow Run URL](${workflowUrl})\n${parsedLogs}`;

            // Run GraphQL query against GitHub API to find the most recent open issue used for reporting failures
            const query = `query($owner:String!, $name:String!, $creator:String!, $label:String!){
              repository(owner: $owner, name: $name) {
                issues(first: 1, states: OPEN, filterBy: {createdBy: $creator, labels: [$label]}, orderBy: {field: CREATED_AT, direction: DESC}) {
                  edges {
                    node {
                      body
                      id
                      number
                    }
                  }
                }
              }
            }`;

            const variables = {
              owner: context.repo.owner,
              name: context.repo.repo,
              label: 'Nightly-CI',
              creator: "github-actions[bot]",
            };
            const result = await github.graphql(query, variables);
            const { edges } = result.repository.issues;

            // If no issue is open, create a new issue,
            // else update the body of the existing issue.
            // Both calls are awaited so that an API error fails this step instead of
            // being lost in a floating promise.
            if (edges.length === 0) {
              await github.rest.issues.create({
                owner: variables.owner,
                repo: variables.name,
                body: issueBody,
                title,
                labels: [variables.label],
              });
            } else {
              await github.rest.issues.update({
                owner: variables.owner,
                repo: variables.name,
                issue_number: edges[0].node.number,
                body: issueBody,
              });
            }