Skip to content

Commit

Permalink
Fix remat bug on primitives with multiple outputs.
Browse files Browse the repository at this point in the history
Addresses #25841

PiperOrigin-RevId: 715125226
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Jan 14, 2025
1 parent 9ba1fd2 commit dedfa0d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,8 @@ def has_effects(effects) -> bool:
outvars_copy = list[Atom](eqn.outvars)
offload_eqn = core.JaxprEqn(
outvars_copy, resvars, device_put_p,
dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None],
dict(devices=[TransferToMemoryKind(policy.dst)
] * len(outvars_copy), srcs=[None],
copy_semantics=[CopySemantics.COPY]),
set(), source_info_util.new_source_info(),
JaxprEqnContext(None, False))
Expand All @@ -1090,7 +1091,8 @@ def has_effects(effects) -> bool:
residuals.update(resvars)
reload_eqn = core.JaxprEqn(
resvars, eqn.outvars, device_put_p,
dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None],
dict(devices=[TransferToMemoryKind(policy.src)
] * len(resvars), srcs=[None],
copy_semantics=[CopySemantics.COPY]),
set(), source_info_util.new_source_info(),
JaxprEqnContext(None, False))
Expand Down
22 changes: 22 additions & 0 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
from jax._src import xla_bridge as xb
from jax._src.layout import DeviceLocalLayout as DLL, Layout
from jax._src import config
from jax._src import core
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
from jax._src.interpreters import ad
from jax.ad_checkpoint import Offloadable, remat, Recompute
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
Expand Down Expand Up @@ -1842,5 +1844,25 @@ def f(x):
if compiled_stats is not None:
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

def test_primitive_with_multiple_outputs(self):
# Test for https://github.com/jax-ml/jax/issues/25841
shape = (128,)
inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

def policy(prim, *args, **kwargs):
del args, kwargs
if prim.multiple_results:
return Offloadable("device", "pinned_host")
return Recompute

@functools.partial(remat, policy=policy)
def test_fn(x):
# Need any primitive with multiple outputs and a non-trivial grad.
x1, _ = jax.lax.approx_max_k(x, k=2)
return jnp.sum(x1)

fn = jax.grad(test_fn)
jax.jit(fn)(inp) # doesn't crash

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit dedfa0d

Please sign in to comment.