Fix handling of integer types during gradient computation with mhlo modules.

PiperOrigin-RevId: 693002319
shaobohou authored and TF2JAXDev committed Nov 4, 2024
1 parent 1194d89 commit 171d7e3
Showing 4 changed files with 81 additions and 12 deletions.
30 changes: 30 additions & 0 deletions tf2jax/_src/roundtrip_test.py
@@ -838,6 +838,36 @@ def grad(dy):
    self.assertAllClose(expected_outputs, re_jax_outputs)
    self.assertAllClose(expected_grads, re_jax_grads)

    # Jax -> TF -> SavedModel -> TF -> Jax -> TF
    tf_forward2 = jax2tf.convert(re_jax_forward, with_gradient=with_grad)
    tf_forward2 = tf.function(tf_forward2, autograph=False)
    concrete_tf_forward2 = tf_forward2.get_concrete_function(
        tf.TensorSpec(shape=(3, 2))
    )
    tf_outputs2 = concrete_tf_forward2(inputs)
    self.assertAllClose(tf_outputs2, re_jax_outputs)

    # Jax -> TF -> SavedModel -> TF -> Jax -> TF -> SavedModel
    model = tf.Module()
    model.f = tf_forward2
    tmp_dir = self.create_tempdir()
    tf.saved_model.save(model, tmp_dir.full_path)
    del model
    restored2 = tf.saved_model.load(tmp_dir.full_path)
    new_tf_outputs = restored2.f(inputs)
    self.assertAllClose(new_tf_outputs, tf_outputs2)

    # Jax -> TF -> SavedModel -> TF -> Jax -> TF -> SavedModel -> Jax
    with config.override_config("convert_custom_gradient", True):
      re_jax_forward2 = tf2jax.convert_functional(
          restored2.f, tf.zeros_like(inputs)
      )
    re_jax_forward2 = self.variant(re_jax_forward2)
    re_jax_outputs2 = re_jax_forward2(inputs)
    re_jax_grads2 = jax.grad(re_jax_forward2)(inputs)
    self.assertAllClose(expected_outputs, re_jax_outputs2)
    self.assertAllClose(expected_grads, re_jax_grads2)

  @chex.variants(with_jit=True)
  def test_custom_gradient_saved_model(self):
    model = tf.saved_model.load(
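For orientation, the new assertions above extend an existing round trip by one more SavedModel hop. Below is a condensed, standalone sketch of the basic Jax -> TF -> SavedModel -> Jax cycle being exercised (a hypothetical repro, not part of the commit; `forward` is a stand-in for the test's custom-gradient function, and `tf2jax.override_config` is assumed to be the public counterpart of the test's `config.override_config`):

import tempfile

import jax
import jax.numpy as jnp
import tensorflow as tf
import tf2jax
from jax.experimental import jax2tf

def forward(x):
  return jnp.sum(jnp.square(x))  # placeholder for the test's custom-gradient function

# Jax -> TF -> SavedModel.
tf_forward = tf.function(jax2tf.convert(forward, with_gradient=True), autograph=False)
module = tf.Module()
module.f = tf_forward
module.f(tf.zeros((3, 2)))  # trace once so a concrete function is saved
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# SavedModel -> TF -> Jax, then differentiate through the converted function.
restored = tf.saved_model.load(export_dir)
with tf2jax.override_config("convert_custom_gradient", True):
  jax_forward = tf2jax.convert_functional(restored.f, tf.zeros((3, 2)))
grads = jax.grad(jax_forward)(jnp.ones((3, 2)))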
9 changes: 8 additions & 1 deletion tf2jax/_src/tf2jax.py
@@ -1443,7 +1443,14 @@ def clear_shape_env():

  # Alternatively, try running get_concrete_function in a separate thread?
  with inside_call_tf(), clear_shape_env():
    concrete_tf_grad_fn = tf_grad_fn.get_concrete_function(*input_specs)
    try:
      concrete_tf_grad_fn = tf_grad_fn.get_concrete_function(*input_specs)
    except NotImplementedError as e:
      logging.info(
          "Failed to get concrete function for %s: %s", grad_fn_name, e
      )
      library[grad_fn_name] = None
      return

  logging.info("Converting gradient function %s", grad_fn_name)
  grad_inputs = concrete_tf_grad_fn.inputs
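For context, the new guard wraps the standard tf.function tracing call against TensorSpec inputs. A minimal sketch of the pattern, under the assumption that TF signals an unconvertible gradient by raising NotImplementedError from get_concrete_function (the toy grad_fn below traces fine and is only illustrative):

import tensorflow as tf

@tf.function(autograph=False)
def grad_fn(dy):
  return dy * 2.0

input_specs = (tf.TensorSpec(shape=(3, 2), dtype=tf.float32),)
try:
  concrete_grad_fn = grad_fn.get_concrete_function(*input_specs)
except NotImplementedError:
  # Mirrors the fallback above: record that no converted gradient is
  # available instead of failing the whole conversion.
  concrete_grad_fn = None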
10 changes: 6 additions & 4 deletions tf2jax/experimental/mhlo.py
@@ -35,11 +35,12 @@
mhlo_apply_p.multiple_results = True


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class MhloModule:
  module: str  # string representation of the MLIR module.
  fun_name: str
  assume_grad_fn: bool = False
  require_platform_index: bool = False

  def __str__(self):
    return f"MhloModule(fun_name={self.fun_name}, ...)"
@@ -62,7 +63,7 @@ def mhlo_apply_impl(*args, module: MhloModule):

# See https://github.com/google/jax/blob/main/jax/_src/interpreters/mlir.py#L115
# for reference
def _ir_type_to_dtype(ir_type: ir.Type) -> jnp.dtype:
def ir_type_to_dtype(ir_type: ir.Type) -> jnp.dtype:
  """Converts MLIR type to JAX dtype."""
  ir_to_jax = {
      ir.IntegerType.get_signless(1): jnp.bool_,
@@ -135,15 +136,16 @@ def mhlo_apply_abstract_eval(
      assert has_polymorphic, has_polymorphic
      if module.assume_grad_fn:
        # TODO(b/329832868) Fix this properly.
        out_shape = in_avals[idx].shape
        offset = 1 if module.require_platform_index else 0
        out_shape = in_avals[idx + offset].shape
      else:
        out_shape = export.symbolic_shape(
            out_shape, like=res.shape, scope=symbolic_scope
        )
    else:
      out_shape = res.shape
    output_specs.append(
        core.ShapedArray(out_shape, _ir_type_to_dtype(res.element_type))
        core.ShapedArray(out_shape, ir_type_to_dtype(res.element_type))
    )
  return tuple(output_specs)
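The float0 dtype referenced by the abstract-eval workaround is JAX's zero-sized tangent type for integer and boolean inputs. A minimal sketch of where it shows up (assumes a JAX version whose jax.grad accepts allow_int):

import jax
import jax.numpy as jnp

def f(x, i):
  # `i` is an integer operand with no real-valued gradient.
  return 2.0 * jnp.sum(x)

grads = jax.grad(f, argnums=(0, 1), allow_int=True)(jnp.ones((2,)), jnp.int32(3))
print(grads[1].dtype == jax.dtypes.float0)  # True: integer inputs get float0 cotangents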
44 changes: 37 additions & 7 deletions tf2jax/experimental/ops.py
Expand Up @@ -45,7 +45,7 @@ def _platform_to_alias(platform: str) -> str:
@functools.lru_cache(None)
def _refine_with_static_input_shapes(
module_text: str, operands: Tuple[jax.core.ShapedArray, ...]
) -> str:
) -> tuple[str, list[jax.core.ShapedArray]]:
"""Refine the polymorphic shapes inside a module."""
# Wrap original main within another function with static input shapes.
context = mlir.make_ir_context()
Expand All @@ -57,8 +57,25 @@ def _refine_with_static_input_shapes(
symbol_table.set_symbol_name(orig_main, "_orig_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value

# Use static shapes
# This help refine polymorphic shapes.
new_main_input_types = [mlir.aval_to_ir_type(x) for x in operands]
# Retain the original element type. This is necessary because
# jax.custom_gradient will replace integer types with the corresponding
# tangent types, i.e. float0.
for idx, (ox, nx) in enumerate(
zip(orig_main.type.inputs, new_main_input_types, strict=True)
):
assert isinstance(nx, ir.RankedTensorType), nx
if ox.element_type != nx.element_type:
new_main_input_types[idx] = ir.RankedTensorType.get(
nx.shape, ox.element_type
)
# Final input specs to be returned.
input_specs = [
jax.core.ShapedArray(x.shape, mhlo.ir_type_to_dtype(x.element_type))
for x in new_main_input_types
]

orig_output_types = orig_main.type.results
new_main_ftype = ir.FunctionType.get(
new_main_input_types, orig_output_types
Expand Down Expand Up @@ -109,7 +126,7 @@ def _refine_with_static_input_shapes(
module,
validate_static_shapes=all([isinstance(x, int) for x in input_dims]),
)
return mlir.module_to_string(module)
return mlir.module_to_string(module), input_specs


@ops.register_operation("XlaCallModule")
Expand Down Expand Up @@ -141,8 +158,9 @@ def _xla_call_module(proto):

dim_args_spec = tuple(proto.attr["dim_args_spec"].list.s)
if dim_args_spec:
raise ValueError("Dynamic shapes is not yet supported, found "
f"dim_args_spec={dim_args_spec}.")
raise ValueError(
"Dynamic shapes is not yet supported, found "
f"dim_args_spec={dim_args_spec}.")

function_list = tuple(proto.attr["function_list"].list.func)
if function_list:
Expand Down Expand Up @@ -217,18 +235,30 @@ def check_platforms():

def _func(*operands: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
platform_index = check_platforms()
if platform_index is not None and len(target_platforms) > 1:
require_platform_index = (
platform_index is not None and len(target_platforms) > 1
)
if require_platform_index:
operands = (jnp.array(platform_index),) + operands

refined_mhlo_text = _refine_with_static_input_shapes(
refined_mhlo_text, input_specs = _refine_with_static_input_shapes(
mhlo_text,
tuple(jax.core.ShapedArray(x.shape, x.dtype) for x in operands),
)
mhlo_module = mhlo.MhloModule(
module=refined_mhlo_text,
fun_name=proto.name,
assume_grad_fn=assume_grad_fn,
require_platform_index=require_platform_index,
)
if assume_grad_fn:
# The change in _refine_with_static_input_shapes is not enough as
# depending on whether we are computing gradient via Jax or TF, integer
# types may or may not be replaced with float0.
operands = [
jnp.zeros(x.shape, y.dtype) if x.dtype == jax.dtypes.float0 else x
for x, y in zip(operands, input_specs, strict=True)
]
return mhlo.mhlo_apply(*operands, module=mhlo_module)

return _func
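The operand fix-up at the end of _func can be seen in isolation. A minimal standalone sketch, assuming cotangents for integer inputs arrive as zero-sized float0 arrays while the refined module still expects the original integer dtypes (the shapes and dtypes below are made up for illustration):

import jax
import jax.numpy as jnp
import numpy as np

float0 = jax.dtypes.float0

# One real-valued operand plus a float0 cotangent standing in for an integer input.
operands = [jnp.ones((2,), jnp.float32), np.zeros((3,), dtype=float0)]
expected_dtypes = [jnp.float32, jnp.int32]  # recovered from the refined input specs

operands = [
    jnp.zeros(x.shape, dtype) if x.dtype == float0 else x
    for x, dtype in zip(operands, expected_dtypes)
]
print([x.dtype for x in operands])  # [dtype('float32'), dtype('int32')]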
