Skip to content

Commit

Permalink
Run dot algorithm tests with PJRT plugin.
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 31, 2024
1 parent c708a04 commit 52ad605
Showing 1 changed file with 0 additions and 10 deletions.
10 changes: 0 additions & 10 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from jax._src import dtypes
from jax._src import lax_reference
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
Expand Down Expand Up @@ -1077,9 +1076,6 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision):
if jtu.dtypes.supported([dtype])
])
def testDotAlgorithm(self, algorithm, dtype):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jtu.test_device_matches(["cpu"]):
if algorithm not in {
lax.DotAlgorithmPreset.DEFAULT,
Expand Down Expand Up @@ -1130,9 +1126,6 @@ def testDotAlgorithm(self, algorithm, dtype):
self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype)

def testDotAlgorithmInvalidFloat8Type(self):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jtu.test_device_matches(["cpu"]):
raise SkipTest("Not supported on CPU.")
lhs_shape = (3, 4)
Expand All @@ -1143,9 +1136,6 @@ def testDotAlgorithmInvalidFloat8Type(self):
lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32")

def testDotAlgorithmCasting(self):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jtu.test_device_matches(["tpu"]):
raise SkipTest("F32_F32_F32 is not supported on TPU.")
def fun(lhs, rhs):
Expand Down

0 comments on commit 52ad605

Please sign in to comment.