# cloud-tpu-ci-nightly.yml
# Cloud TPU CI
#
# This job currently runs twice per day. We use self-hosted TPU runners, so we'd
# have to add more runners to run on every commit.
#
# This job's build matrix runs over several TPU architectures and several
# jaxlib configurations: jaxlib built from source at head ("head"), the latest
# released jaxlib on PyPI ("pypi_latest"), and the latest nightly jaxlib
# ("nightly" and "nightly+oldest_supported_libtpu"). It also installs a matching
# libtpu: the one pinned to the release for "pypi_latest", the latest
# pre-release for "head" and "nightly", or the oldest supported build for
# "nightly+oldest_supported_libtpu". It always locally installs jax from GitHub
# head (already checked out by the GitHub Actions environment).
name: CI - Cloud TPU (nightly)
on:
  schedule:
    # GitHub evaluates cron schedules in UTC: this fires at 02:00 and 14:00 UTC,
    # i.e. 7pm and 7am PDT (6pm and 6am PST).
    - cron: "0 2,14 * * *"
  workflow_dispatch: # allows triggering the workflow run manually

# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read
jobs:
  # Runs the JAX test suite on TPU runners, once for every combination of
  # jaxlib configuration and TPU type listed in the matrix below.
  cloud-tpu-test:
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        # Selects which branch of the "Install JAX" step is taken.
        jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
        # `cores` is passed to pytest as the xdist worker count (-n); `runner`
        # is the runs-on label for that TPU type.
        tpu:
          # - {type: "v3-8", cores: "4"}  # Enable when we have the v3 type available
          - {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}
          - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
        python-version: ["3.10"]
    name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
    env:
      # Date stamp of the libtpu-nightly build pinned by the
      # "nightly+oldest_supported_libtpu" configuration. Quoted so that it is
      # always handled as a string, never as an integer.
      LIBTPU_OLDEST_VERSION_DATE: "20240922"
      PYTHON: python${{ matrix.python-version }}
    runs-on: ${{ matrix.tpu.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    timeout-minutes: 180
    defaults:
      run:
        # -e aborts a step at the first failing command; -x echoes every command
        # into the job log.
        shell: bash -ex {0}
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      # Checkout XLA at head, if we're building jaxlib at head.
      - name: Checkout XLA at head
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        if: ${{ matrix.jaxlib-version == 'head' }}
        with:
          repository: openxla/xla
          path: xla
      # We need to mark the GitHub workspace as safe as otherwise git commands will fail.
      - name: Mark GitHub workspace as safe
        run: |
          git config --global --add safe.directory "$GITHUB_WORKSPACE"
      - name: Install JAX test requirements
        run: |
          $PYTHON -m pip install -U -r build/test-requirements.txt
          $PYTHON -m pip install -U -r build/collect-profile-requirements.txt
      # Removes any preinstalled jax/jaxlib/libtpu, installs the combination
      # selected by matrix.jaxlib-version, then prints the resulting versions so
      # they are visible in the job log.
      - name: Install JAX
        run: |
          $PYTHON -m pip uninstall -y jax jaxlib libtpu
          if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
            # Build and install jaxlib at head
            $PYTHON build/build.py build --wheels=jaxlib \
              --bazel_options=--config=rbe_linux_x86_64 \
              --local_xla_path="$(pwd)/xla" \
              --verbose
            $PYTHON -m pip install dist/*.whl
            # Install "jax" at head
            $PYTHON -m pip install -U -e .
            # Install libtpu
            $PYTHON -m pip install --pre libtpu \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
            $PYTHON -m pip install .[tpu] \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            $PYTHON -m pip install --pre . \
              -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
            $PYTHON -m pip install --pre libtpu \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
            $PYTHON -m pip install requests
          elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
            $PYTHON -m pip install --pre . \
              -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
            # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
            $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
            $PYTHON -m pip install requests
          else
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi
          $PYTHON -c 'import sys; print("python version:", sys.version)'
          $PYTHON -c 'import jax; print("jax version:", jax.__version__)'
          $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
          # NOTE(review): jax.lib.xla_bridge may be a deprecated import path in
          # current JAX - confirm it still resolves against jax at head.
          $PYTHON -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'
      # Three pytest invocations run back to back. Because the shell runs with
      # -e, a failing invocation ends the step and the later ones are skipped.
      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
          PY_COLORS: 1
        run: |
          # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic TPU does not
          # guarantee anything about forward compatibility (unless jax.export is used) and the 12
          # week compatibility window accumulates way too many failures.
          IGNORE_FLAGS=
          if [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
            IGNORE_FLAGS="--ignore=tests/pallas"
          fi
          # Run single-accelerator tests in parallel
          JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
            --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
            --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
          # Run Pallas printing tests, which need to run with I/O capturing disabled.
          TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
            tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
          # Run multi-accelerator across all chips
          $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
      - name: Send chat on failure
        # Don't notify when testing the workflow from a branch.
        if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
        env:
          # Hand the webhook URL to the script through the environment instead of
          # templating the secret into the shell source.
          BUILD_CHAT_WEBHOOK: ${{ secrets.BUILD_CHAT_WEBHOOK }}
        run: |
          curl --location --request POST "$BUILD_CHAT_WEBHOOK" \
            --header 'Content-Type: application/json' \
            --data-raw "{
          'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
          }"