# Introduce GPU testing to JAX Triton (#224)
# Note carried over from the GitHub file view: this file may contain bidirectional Unicode text
# that can be interpreted differently than it appears; review it in an editor that reveals
# hidden Unicode characters. See GitHub's documentation on bidirectional Unicode characters.
# Workflow name shown in the GitHub Actions UI.
name: ci
# Run on every push to main and on every pull request that targets main.
# `on` is a YAML 1.1 boolean for generic parsers; GitHub's loader treats it as a key.
on:  # yamllint disable-line rule:truthy
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
# Least-privilege GITHUB_TOKEN: the jobs only need to read the repository.
permissions:
  contents: read  # to fetch code
# At most one run per workflow and branch/PR ref (head_ref for pull requests,
# ref otherwise); a newer run cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Runs the repository's pre-commit hooks (via pre-commit/action) under Python 3.10.
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332  # ratchet:actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3  # ratchet:actions/setup-python@v5
        with:
          python-version: '3.10'  # quoted so YAML does not read it as the float 3.1
      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd  # ratchet:pre-commit/action@v3.0.1
  # Installs released JAX with CUDA 12 support plus this package, then runs
  # tests/ with pytest inside a container on the self-hosted GPU runner.
  test:
    runs-on: linux-x86-g2-48-l4-4gpu
    container:
      # TODO: change image based on what is needed for these tests
      image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2  # ratchet:tensorflow/build:latest-python3.11
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332  # ratchet:actions/checkout@v4
      # NOTE(review): cuda-compat only installs the forward-compatibility driver
      # libraries; NVIDIA's CUDA compatibility guide says they take effect only
      # once they are on the loader path (e.g. LD_LIBRARY_PATH pointing at the
      # package's compat directory). Confirm the runner/container already does this.
      - name: Setup Compat Driver
        run: |
          # This container should already have the CUDA apt repos setup
          apt-get update
          apt-get install -y --no-install-recommends cuda-compat-12-6
      - name: Setup Released JAX
        run: |
          pip install -U "jax[cuda12]"
          pip install pytest
      - name: Test JAX Triton
        run: |
          echo "Running JAX Triton GPU Tests"
          nvidia-smi
          pip install .
          # Need newer ml-dtypes because we install newer numpy
          pip install --upgrade ml-dtypes
          pytest -v --tb=short tests/