From dedfa0d54452936bdb0f19881f8270cc701d3045 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 13 Jan 2025 15:17:36 -0800 Subject: [PATCH] Fix remat bug on primitives with multiple outputs. Addresses https://github.com/jax-ml/jax/issues/25841 PiperOrigin-RevId: 715125226 --- jax/_src/interpreters/partial_eval.py | 6 ++++-- tests/memories_test.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ac0ae3a13967..587b534c242d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1080,7 +1080,8 @@ def has_effects(effects) -> bool: outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None], + dict(devices=[TransferToMemoryKind(policy.dst) + ] * len(outvars_copy), srcs=[None], copy_semantics=[CopySemantics.COPY]), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) @@ -1090,7 +1091,8 @@ def has_effects(effects) -> bool: residuals.update(resvars) reload_eqn = core.JaxprEqn( resvars, eqn.outvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None], + dict(devices=[TransferToMemoryKind(policy.src) + ] * len(resvars), srcs=[None], copy_semantics=[CopySemantics.COPY]), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) diff --git a/tests/memories_test.py b/tests/memories_test.py index 12353d49a228..25c0483c8d29 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -27,8 +27,10 @@ from jax._src import xla_bridge as xb from jax._src.layout import DeviceLocalLayout as DLL, Layout from jax._src import config +from jax._src import core from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp +from jax._src.interpreters import ad from jax.ad_checkpoint import Offloadable, remat, Recompute from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import (NamedSharding, PositionalSharding, @@ -1842,5 +1844,25 @@ def f(x): if compiled_stats is not None: self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + def test_primitive_with_multiple_outputs(self): + # Test for https://github.com/jax-ml/jax/issues/25841 + shape = (128,) + inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + + def policy(prim, *args, **kwargs): + del args, kwargs + if prim.multiple_results: + return Offloadable("device", "pinned_host") + return Recompute + + @functools.partial(remat, policy=policy) + def test_fn(x): + # Need any primitive with multiple outputs and a non-trivial grad. + x1, _ = jax.lax.approx_max_k(x, k=2) + return jnp.sum(x1) + + fn = jax.grad(test_fn) + jax.jit(fn)(inp) # doesn't crash + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())