diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index b1a5c4b0294..034dd90956c 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -7,7 +7,7 @@ torch>=2.1.0 torchvision>=0.16.0 # Jax with cuda support. -jax[cuda12]==0.4.29 +jax[cuda12]==0.4.28 flax -r requirements-common.txt