diff --git a/jax/BUILD b/jax/BUILD index 271f5e95a626..175f0f6fe1e8 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -341,6 +341,7 @@ pytype_strict_library( ":config", ":core", ":dtypes", + ":state_types", ":traceback_util", ":tree_util", ":util", diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 97052c955c69..b2bed4d4b236 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -23,7 +23,9 @@ import numpy as np from jax._src import core +from jax._src import config from jax._src import dtypes +from jax._src.state.types import AbstractRef from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ShapedArray from jax._src.tree_util import ( @@ -737,3 +739,31 @@ def __eq__(self, other): def register_class_with_attrs(t: type) -> None: _class_with_attrs.add(t) _class_with_attrs: set[type] = set() + +# TODO(mattjj): make this function faster +def _check_no_aliased_ref_args(dbg, avals, args): + assert config.mutable_array_checks.value + refs: dict[int, int] = {} + for i, (a, x) in enumerate(zip(avals, args)): + if (isinstance(a, AbstractRef) and + (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): + raise ValueError( + "only one reference to a mutable array may be passed as an argument " + f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " + f"the mutable array reference of type {a.str_short()} appeared at both " + f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}." + if dbg else + f"at both flat index {dup_idx} and flat index {i}") from None + +def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None: + assert config.mutable_array_checks.value + refs: set[int] = {id(core.get_referent(c)) for c in consts + if isinstance(core.get_aval(c), AbstractRef)} + for i, x in enumerate(args): + if id(core.get_referent(x)) in refs: + a = shaped_abstractify(x) + raise ValueError( + f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable " + f"array reference of type {a.str_short()} was both closed over and " + f"passed as the argument " + f"{dbg.arg_names[i]}" if dbg else "at flat index {i}") diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 1e98c19476c2..4a68e7b8ec43 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -34,7 +34,9 @@ from jax._src import source_info_util from jax._src import state from jax._src import util -from jax._src.api_util import shaped_abstractify +from jax._src.api_util import ( + shaped_abstractify, _check_no_aliased_ref_args, + _check_no_aliased_closed_over_refs) from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -271,13 +273,20 @@ def scan(f, init, xs, length=None): xs_avals = [core.get_aval(x) for x in xs_flat] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] + if config.mutable_array_checks.value: + in_flat, in_tree = tree_flatten((init, xs)) + dbg = pe.debug_info(f, in_tree, None, False, 'scan') + in_avals = tuple(_map(core.get_aval, in_flat)) + _check_no_aliased_ref_args(dbg, in_avals, in_flat) + def _create_jaxpr(init): init_flat, init_tree = tree_flatten(init) in_flat, in_tree = tree_flatten((init, xs)) - carry_avals = tuple(_map(core.get_aval, init_flat)) jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs( f, in_tree, (*carry_avals, *x_avals), "scan") + if config.mutable_array_checks.value: + _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), in_flat) out_tree_children = out_tree.children() if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ed91e76fa38a..1641493b64a5 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -49,7 +49,8 @@ argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, donation_vector, shaped_abstractify, check_callable, resolve_argnums, argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, - hoist_obj_attrs) + hoist_obj_attrs, _check_no_aliased_ref_args, + _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec from jax._src.interpreters import xla @@ -627,7 +628,8 @@ def _infer_params_impl( flat_fun, in_type, attr_token, dbg, HashableFunction(res_paths, closure=()), IgnoreKey(ji.inline)) - _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) + if config.mutable_array_checks.value: + _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) _attr_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( @@ -764,33 +766,9 @@ def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...] " static_argnums or static_argnames parameters of jax.jit." ) from None if config.mutable_array_checks.value: - # TODO(mattjj): make this faster - refs: dict[int, int] = {} - for i, (a, x) in enumerate(zip(avals, explicit_args)): - if (isinstance(a, AbstractRef) and - (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): - raise ValueError( - "only one reference to a mutable array may be passed as an argument " - f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " - f"the mutable array reference of type {a.str_short()} appeared at both " - f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}." - if dbg else - f"at both flat index {dup_idx} and flat index {i}") from None + _check_no_aliased_ref_args(dbg, avals, explicit_args) return tuple(avals) -def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None: - if not config.mutable_array_checks.value: return - refs: set[int] = {id(core.get_referent(c)) for c in consts - if isinstance(core.get_aval(c), AbstractRef)} - for i, x in enumerate(args): - if id(core.get_referent(x)) in refs: - a = shaped_abstractify(x) - raise ValueError( - f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable " - f"array reference of type {a.str_short()} was both closed over and " - f"passed as the argument " - f"{dbg.arg_names[i]}" if dbg else "at flat index {i}") - def _extract_implicit_args( in_type: Sequence[tuple[core.AbstractValue, bool]], explicit_args: Sequence[Any] diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 14f08d2dcdaa..ce355f86ab67 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -206,7 +206,6 @@ def test_scan_scanned_mut_array(self, jit): def body_fun(_, index_x): (index, x) = index_x x[...] += index - # breakpoint() return ((), x[...]) x_mut = core.mutable_array(np.arange(5)) @@ -289,8 +288,18 @@ def test_return_from_scan(self): ValueError, "traced for scan returned a mutable array reference of type"): jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3)) - # TODO test_argument_aliases_scan - # TODO test_closure_and_argument_aliases_scan + def test_argument_aliases_scan(self): + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( + ValueError, r"appeared at both c\[0\] and c\[1\]"): + jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1) + + def test_closure_and_argument_aliases_scan(self): + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( + ValueError, r"closed over and passed as the argument y_ref"): + jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref, + None, length=1) def test_return_from_cond(self): with self.assertRaisesRegex(