Skip to content

Commit

Permalink
Fix jax2tf failure coming from dot_general
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688738110
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Oct 22, 2024
1 parent f8a1f02 commit 4688da3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/impl_no_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def _conv_general_dilated(
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: tuple[PrecisionType, PrecisionType] | None,
preferred_element_type: DType | None,
out_type=None,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2180,9 +2180,10 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None):
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated


def _dot_general(lhs, rhs, *, dimension_numbers, out_type,
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: lax_internal.CanonicalPrecision,
preferred_element_type: DType | None,
out_type=None,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
Expand Down

0 comments on commit 4688da3

Please sign in to comment.