diff --git a/tests/lax_test.py b/tests/lax_test.py index 79ad9fcfa02c..f2ce0913e03a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -41,7 +41,6 @@ from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util @@ -1077,9 +1076,6 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): if jtu.dtypes.supported([dtype]) ]) def testDotAlgorithm(self, algorithm, dtype): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): if algorithm not in { lax.DotAlgorithmPreset.DEFAULT, @@ -1130,9 +1126,6 @@ def testDotAlgorithm(self, algorithm, dtype): self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype) def testDotAlgorithmInvalidFloat8Type(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): raise SkipTest("Not supported on CPU.") lhs_shape = (3, 4) @@ -1143,9 +1136,6 @@ def testDotAlgorithmInvalidFloat8Type(self): lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32") def testDotAlgorithmCasting(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["tpu"]): raise SkipTest("F32_F32_F32 is not supported on TPU.") def fun(lhs, rhs):