Skip to content

Commit

Permalink
Fix jax cuda requirements (#20295)
Browse files Browse the repository at this point in the history
* fix jax cuda requirements

* fix jax cuda requirements

* fix jax cuda requirements

* fix jax cuda requirements

* fix jax cuda requirements
  • Loading branch information
sampathweb authored Sep 27, 2024
1 parent 93cb954 commit 3fb091f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 4 additions & 4 deletions keras/src/backend/tests/device_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ def test_jax_device_scope(self):

with backend.device_scope("cpu:0"):
t = backend.numpy.ones((2, 1))
self.assertEqual(t.device(), jax.devices("cpu")[0])
self.assertEqual(t.device, jax.devices("cpu")[0])
with backend.device_scope("CPU:0"):
t = backend.numpy.ones((2, 1))
self.assertEqual(t.device(), jax.devices("cpu")[0])
self.assertEqual(t.device, jax.devices("cpu")[0])

# When leaving the scope, the device should be back with gpu:0
t = backend.numpy.ones((2, 1))
self.assertEqual(t.device(), jax.devices("gpu")[0])
self.assertEqual(t.device, jax.devices("gpu")[0])

# Also verify the explicit gpu device
with backend.device_scope("gpu:0"):
t = backend.numpy.ones((2, 1))
self.assertEqual(t.device(), jax.devices("gpu")[0])
self.assertEqual(t.device, jax.devices("gpu")[0])

@pytest.mark.skipif(backend.backend() != "jax", reason="jax only")
def test_invalid_jax_device(self):
Expand Down
4 changes: 3 additions & 1 deletion requirements-jax-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ torch>=2.1.0
torchvision>=0.16.0

# Jax with cuda support.
jax[cuda12_pip]
# TODO: Higher version breaks CI.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12]==0.4.28
flax

-r requirements-common.txt

0 comments on commit 3fb091f

Please sign in to comment.