diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a332e72..79de096 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: shell: bash -l {0} run: | pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - name: Check nvidia drivers shell: bash -l {0} run: |