From 52ad60521cffaa4d175cb5dcc8c8a8c5bc3d6738 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 23 Oct 2024 09:17:27 -0400 Subject: [PATCH] Run dot algorithm tests with PJRT plugin. --- tests/lax_test.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 79ad9fcfa02c..f2ce0913e03a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -41,7 +41,6 @@ from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util @@ -1077,9 +1076,6 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): if jtu.dtypes.supported([dtype]) ]) def testDotAlgorithm(self, algorithm, dtype): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): if algorithm not in { lax.DotAlgorithmPreset.DEFAULT, @@ -1130,9 +1126,6 @@ def testDotAlgorithm(self, algorithm, dtype): self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype) def testDotAlgorithmInvalidFloat8Type(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): raise SkipTest("Not supported on CPU.") lhs_shape = (3, 4) @@ -1143,9 +1136,6 @@ def testDotAlgorithmInvalidFloat8Type(self): lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32") def testDotAlgorithmCasting(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["tpu"]): raise SkipTest("F32_F32_F32 is not supported on TPU.") def fun(lhs, rhs):