diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 3cc484afb4c..b1a5c4b0294 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -7,7 +7,7 @@ torch>=2.1.0 torchvision>=0.16.0 # Jax with cuda support. -jax[cuda12] +jax[cuda12]==0.4.29 flax -r requirements-common.txt