From 4688da31183d2270c5059af46b5065c6e0a1d077 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 22 Oct 2024 16:53:03 -0700 Subject: [PATCH] Fix jax2tf failure coming from dot_general PiperOrigin-RevId: 688738110 --- jax/experimental/jax2tf/impl_no_xla.py | 1 + jax/experimental/jax2tf/jax2tf.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 310cbaab6d59..0d8c95d42676 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -364,6 +364,7 @@ def _conv_general_dilated( def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, + out_type=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index dcf9cafb5117..29a1034e51ed 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2180,9 +2180,10 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None): tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated -def _dot_general(lhs, rhs, *, dimension_numbers, out_type, +def _dot_general(lhs, rhs, *, dimension_numbers, precision: lax_internal.CanonicalPrecision, preferred_element_type: DType | None, + out_type=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""