diff --git a/tf2jax/experimental/mhlo.py b/tf2jax/experimental/mhlo.py index 4dc7664..ca3895a 100644 --- a/tf2jax/experimental/mhlo.py +++ b/tf2jax/experimental/mhlo.py @@ -127,12 +127,16 @@ def mhlo_apply_abstract_eval( assert has_polymorphic, has_polymorphic if jax.__version_info__ <= (0, 4, 14): from jax.experimental.jax2tf import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error - else: + out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access # pytype: disable=module-attr + elif jax.__version_info__ <= (0, 4, 20): + from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error + out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access # pytype: disable=module-attr + elif jax.__version_info__ <= (0, 4, 23): from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error - if jax.__version_info__ <= (0, 4, 20): - out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access - else: out_shape = shape_poly.symbolic_shape(out_shape, like=res.shape) + else: + from jax.experimental import export # pylint: disable=g-import-not-at-top # pytype: disable=import-error + out_shape = export.symbolic_shape(out_shape, like=res.shape) else: out_shape = res.shape output_specs.append(