Skip to content

Commit

Permalink
Replace shape_poly._parse_spec with shape_poly.symbolic_shape.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 585974174
  • Loading branch information
shaobohou authored and TF2JAXDev committed Nov 28, 2023
1 parent b5c2fd5 commit e228b5b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tf2jax/experimental/mhlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def mhlo_apply_abstract_eval(
from jax.experimental.jax2tf import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error
else:
from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error
out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access
if jax.__version_info__ <= (0, 4, 20):
out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access
else:
out_shape = shape_poly.symbolic_shape(out_shape, like=res.shape)
else:
out_shape = res.shape
output_specs.append(
Expand Down

0 comments on commit e228b5b

Please sign in to comment.