# Cloud TPU nightly
# Cloud TPU nightly #641

# Scheduled CI workflow that installs JAX from this checkout and runs the
# `tests` and `examples` suites with pytest, once per jaxlib flavour in the
# matrix (released "latest" and "nightly").
name: Cloud TPU nightly

on:
  schedule:
    # GitHub evaluates cron schedules in UTC: 14:00 UTC is 7am PDT / 6am PST.
    - cron: "0 14 * * *"
  workflow_dispatch:  # allows triggering the workflow run manually

jobs:
  cloud-tpu-test:
    # NOTE(review): "v4-8" is a runner label; presumably a self-hosted
    # Cloud TPU v4-8 machine -- confirm against the runner configuration.
    runs-on: v4-8
    defaults:
      run:
        # Login shell (-l) for every `run` step in this job.
        shell: bash -l {0}
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Quoted so YAML keeps the string "3.10" instead of the float 3.1.
        python-version: ["3.10"]  # TODO(jakevdp): update to 3.11 when available.
        jaxlib-version: [latest, nightly]
    steps:
      - uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install JAX test requirements
        run: |
          pip install -r build/test-requirements.txt

      - name: Install JAX
        run: |
          # Start from a clean slate so the versions installed below are the
          # ones that actually get imported.
          pip uninstall -y jax jaxlib libtpu-nightly
          if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
            # Quote the requirement: unquoted, bash treats `[tpu]` as a glob
            # character class and could expand it to a matching dotfile.
            pip install ".[tpu]" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            # jax from this checkout, then pre-release jaxlib and
            # libtpu-nightly from their respective release indexes.
            pip install .
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
            pip install libtpu-nightly \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          else
            # Fail loudly if the matrix gains a value this script does not handle.
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi
          # Record the resolved versions in the job log.
          python3 -c 'import jax; print("jax version:", jax.__version__)'
          python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python3 -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'

      - name: Run tests
        env:
          # NOTE(review): presumably a priority-ordered backend list (TPU
          # first, CPU also available) -- confirm against the JAX docs.
          JAX_PLATFORMS: tpu,cpu
        run: python -m pytest --tb=short tests examples