Skip to content

Commit

Permalink
use pip for jax
Browse files Browse the repository at this point in the history
  • Loading branch information
kp992 committed May 27, 2024
1 parent 4ad93be commit 95dd642
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ jobs:
- name: Install CUDA
shell: bash -l {0}
run: |
conda install anaconda::cudatoolkit -y
conda install cudatoolkit cuda-nvcc -c nvidia -c anaconda -y
- name: Install JAX[CUDA] and Numpyro[CUDA]
shell: bash -l {0}
run: |
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia -y
conda install conda-forge::numpyro -y
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "numpyro[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- name: Check nvidia drivers
shell: bash -l {0}
run: |
Expand Down

0 comments on commit 95dd642

Please sign in to comment.