diff --git a/jax/_src/config.py b/jax/_src/config.py index 2e71a89ea2b1..aafaf0ce0df5 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1479,6 +1479,12 @@ def _update_disable_jit_thread_local(val): upgrade=True, help='Disable the check from #19009 to enable some custom_vjp hacks.') +mutable_array_checks = bool_state( + name='jax_mutable_array_checks', + default=False, + upgrade=True, + help='Enable error checks for mutable arrays that rule out aliasing.') + xla_runtime_errors = bool_state( name='jax_experimental_unsafe_xla_runtime_errors', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index 59783466b151..c75206894929 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1917,6 +1917,8 @@ def mutable_array_abstract_eval(init_aval): def _mutable_array_impl(init_val): from jax._src.state.types import AbstractRef # pytype: disable=import-error aval = get_aval(init_val) + # TODO(mattjj): improve spelling of 'defensive copy' here, avoid circular dep + init_val = init_val.copy() if hasattr(init_val, 'copy') else init_val return MutableArray(AbstractRef(aval), init_val) def freeze(ref): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a5984294ee8a..f188687f16f0 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -986,21 +986,11 @@ def partial_eval_jaxpr_custom( ensure_out_inst: bool | Sequence[bool], saveable: Callable[..., RematCases_], ) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]: - if type(in_inst) is bool: - in_inst = (in_inst,) * len(jaxpr.invars) - if type(ensure_out_unknowns) is bool: - ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars) - if type(ensure_out_inst) is bool: - ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars) - jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ - _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns), - tuple(in_inst), - tuple(ensure_out_unknowns), - tuple(ensure_out_inst), saveable) - if num_res_ref > 0: - raise ValueError( - "Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.") - return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res + *outs, num_res_ref = partial_eval_jaxpr_stateful( + jaxpr, in_unknowns, in_inst, ensure_out_unknowns, ensure_out_inst, saveable) + if num_res_ref: + raise ValueError("Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.") + return *outs, # type: ignore def partial_eval_jaxpr_stateful( jaxpr: Jaxpr, @@ -1019,10 +1009,9 @@ def partial_eval_jaxpr_stateful( if saveable is None: saveable = everything_saveable jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ - _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns), - tuple(in_inst), - tuple(ensure_out_unknowns), - tuple(ensure_out_inst), saveable) + _partial_eval_jaxpr_custom_cached( + jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns), + tuple(ensure_out_inst), saveable) return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref everything_saveable = lambda *_, **__: True @@ -2165,12 +2154,45 @@ def trace_to_jaxpr_dynamic( ans = fun.call_wrapped(*in_tracers) out_tracers = map(trace.to_jaxpr_tracer, ans) + _check_no_refs(debug_info, out_tracers) jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers) del trace, fun, in_tracers, out_tracers, ans config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked +def _check_no_refs( + dbg: lu.TracingDebugInfo | None, + out_tracers: Sequence[DynamicJaxprTracer] +) -> None: + if not config.mutable_array_checks.value: return + for i, t in enumerate(out_tracers): + a = t.aval + if isinstance(a, AbstractRef): + if dbg is None: + raise ValueError( + f"function returned a mutable array reference of type {a.str_short()}, " + "but mutable array references cannot be returned.") + loc = (f' at output tree path {keystr(ls[i])}' # type: ignore + if dbg.result_paths and (ls := dbg.result_paths()) and ls[i] else '') + frame = t._trace.frame + v = frame.tracer_to_var.get(id(t)) + eqn = next((e for e in frame.eqns if v in e.outvars), None) + if eqn: + assert eqn.primitive is core.mutable_array_p + origin_info = ('\n\nThe returned mutable array was created on line ' + f'{source_info_util.summarize(eqn.source_info)}.') + elif v in frame.invars: + arg_name = dbg.arg_names[frame.invars.index(v)] + origin_info = ('\n\nThe returned mutable array was passed in as the ' + f'argument {arg_name}.') + else: + origin_info = '' + raise ValueError( + f"function {dbg.func_src_info} traced for {dbg.traced_for} returned " + f"a mutable array reference of type {a.str_short()}{loc}, but " + f"mutable array references cannot be returned.{origin_info}") + @profiler.annotate_function def trace_to_jaxpr_dynamic2( fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index aafaaad68269..ed91e76fa38a 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -556,17 +556,14 @@ def _infer_params_impl( "pjit does not support kwargs when in_shardings is specified.") if pjit_mesh is not None: - jit_name = 'pjit' if (ji.backend or ji.device) and not pjit_mesh.empty: raise ValueError( "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit.") - else: - jit_name = 'jit' axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs) - dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs, + dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs, ji.static_argnums, ji.static_argnames) f = lu.wrap_init(fun) f, res_paths = result_paths(f) @@ -593,6 +590,7 @@ def _infer_params_impl( in_shardings_leaves = out_shardings_leaves = tuple(leaves) in_shardings_treedef = out_shardings_treedef = treedef else: + jit_name = 'pjit' if pjit_mesh is not None else 'jit' in_shardings_leaves = tuple( _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name) for x in ji.in_shardings_leaves) @@ -607,35 +605,12 @@ def _infer_params_impl( in_type: core.InputType | tuple[core.AbstractValue, ...] if config.dynamic_shapes.value: + assert in_avals is None in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) in_avals = tuple(a for a, e in in_type if e) - elif in_avals is None: - avals = [] - for i, a in enumerate(explicit_args): - try: - avals.append(shaped_abstractify(a)) - except OverflowError as e: - arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg - else f"flattened argument number is {i}") - raise OverflowError( - "An overflow was encountered while parsing an argument to a jitted " - f"computation, whose {arg_path}." - ) from e - except TypeError as e: - arg_description = (f"path {dbg.arg_names[i]}" if dbg - else f"flattened argument number {i}") - raise TypeError( - f"Error interpreting argument to {fun} as an abstract array." - f" The problematic value is of type {type(a)} and was passed to" - f" the function at {arg_description}.\n" - "This typically means that a jit-wrapped function was called with a non-array" - " argument, and this argument was not marked as static using the" - " static_argnums or static_argnames parameters of jax.jit." - ) from e - - in_type = in_avals = tuple(avals) else: - in_type = in_avals + in_type = in_avals # type: ignore + assert in_avals is not None in_shardings_flat, in_layouts_flat = _process_in_axis_resources( in_shardings_treedef, in_shardings_leaves, @@ -652,6 +627,7 @@ def _infer_params_impl( flat_fun, in_type, attr_token, dbg, HashableFunction(res_paths, closure=()), IgnoreKey(ji.inline)) + _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) _attr_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( @@ -693,7 +669,6 @@ def _infer_params_impl( donated_invars, dbg.arg_names if dbg else None, len(consts), attrs_tracked, abstract_mesh), args_flat - def get_abstract_mesh_from_avals(in_avals): if not config.sharding_in_types.value: return None @@ -711,9 +686,7 @@ def get_abstract_mesh_from_avals(in_avals): class InferParamsCacheEntry: """Mutable value object for _infer_params_cached.""" __slots__ = ['pjit_params'] - pjit_params: PjitParams | None - def __init__(self): self.pjit_params = None @@ -747,34 +720,76 @@ def _infer_params( resource_env = None pjit_mesh = None - skip_cache = config.dynamic_shapes.value - if not skip_cache: - signature, dynargs = jax_jit.parse_arguments( - args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, - ji.static_argnames, tree_util.default_registry) - try: - avals = tuple(shaped_abstractify(a) for a in dynargs) - except (OverflowError, TypeError): - # If we see something we don't understand, use the slow path. - skip_cache = True - - if skip_cache: + if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=None) return p, p.consts + args_flat - entry = _infer_params_cached( - fun, ji, signature, avals, pjit_mesh, resource_env) + signature, dynargs = jax_jit.parse_arguments( + args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, + ji.static_argnames, tree_util.default_registry) + dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs, + ji.static_argnums, ji.static_argnames) + avals = _infer_input_type(fun, dbg, dynargs) + entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env) if entry.pjit_params is None: p, args_flat = _infer_params_impl( fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals) - if p.attrs_tracked: - # If there are attrs_tracked, don't use the cache. + if p.attrs_tracked: # if attrs, don't popoulate the cache return p, p.consts + args_flat - else: - entry.pjit_params = p + entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs +def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]: + avals = [] + try: + for i, x in enumerate(explicit_args): + avals.append(shaped_abstractify(x)) + except OverflowError: + arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg # type: ignore + else f"flattened argument number is {i}") # type: ignore + raise OverflowError( + "An overflow was encountered while parsing an argument to a jitted " + f"computation, whose {arg_path}." + ) from None + except TypeError: + arg_description = (f"path {dbg.arg_names[i]}" if dbg # type: ignore + else f"flattened argument number {i}") # type: ignore + raise TypeError( + f"Error interpreting argument to {fun} as an abstract array." + f" The problematic value is of type {type(x)} and was passed to" # type: ignore + f" the function at {arg_description}.\n" + "This typically means that a jit-wrapped function was called with a non-array" + " argument, and this argument was not marked as static using the" + " static_argnums or static_argnames parameters of jax.jit." + ) from None + if config.mutable_array_checks.value: + # TODO(mattjj): make this faster + refs: dict[int, int] = {} + for i, (a, x) in enumerate(zip(avals, explicit_args)): + if (isinstance(a, AbstractRef) and + (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): + raise ValueError( + "only one reference to a mutable array may be passed as an argument " + f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " + f"the mutable array reference of type {a.str_short()} appeared at both " + f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}." + if dbg else + f"at both flat index {dup_idx} and flat index {i}") from None + return tuple(avals) + +def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None: + if not config.mutable_array_checks.value: return + refs: set[int] = {id(core.get_referent(c)) for c in consts + if isinstance(core.get_aval(c), AbstractRef)} + for i, x in enumerate(args): + if id(core.get_referent(x)) in refs: + a = shaped_abstractify(x) + raise ValueError( + f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable " + f"array reference of type {a.str_short()} was both closed over and " + f"passed as the argument " + f"{dbg.arg_names[i]}" if dbg else "at flat index {i}") def _extract_implicit_args( in_type: Sequence[tuple[core.AbstractValue, bool]], diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 4eb354a8d50f..2334a7b98f91 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -307,6 +307,7 @@ def body(x, __): self.assertAllClose(thing.x, 1024., check_dtypes=False) def test_arg_to_jit(self): + self.skipTest("regressed this experimental feature") # TODO(mattjj) thing = Thing(1.0) count = 0 diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index f1b80f32446a..14f08d2dcdaa 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -48,13 +48,6 @@ def f(x_mut): jaxpr = jax.make_jaxpr(f)(x_mut) self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) - # disabling this test for now. TODO(dougalm): re-enable once we add checks to - # ensure mutable arrays aren't returned or duplicated etc. - # def test_staging_error(self): - # x = jnp.zeros(3) - # with self.assertRaises(Exception): - # jax.jit(core.mutable_array)(x) - @parameterized.parameters([True, False]) def test_multiple_inputs_and_outputs(self, jit): def f(x_mut, y, z_mut, w): @@ -244,6 +237,91 @@ def f(x): expected = 2.0 self.assertAllClose(ans, expected, check_dtypes=False) + def test_defensive_copy(self): + x = jnp.arange(3.) + _ = jax.jit(lambda x_ref: x_ref[...])(core.mutable_array(x)) + x + 1 # don't crash + + +@jtu.with_config(jax_mutable_array_checks=True) +class MutableArrayErrorsTest(jtu.JaxTestCase): + def test_return_from_jit(self): + with self.assertRaisesRegex( + ValueError, + r"traced for jit returned a mutable array reference.*\n\n" + r".*was created on line"): + jax.jit(core.mutable_array)(jnp.arange(3)) + + def test_return_from_jit_arg(self): + with self.assertRaisesRegex( + ValueError, + r"traced for jit returned a mutable array reference.*\n\n" + r".*was passed in as the argument x_ref"): + jax.jit(lambda x_ref: x_ref)(core.mutable_array(jnp.arange(3))) + + def test_return_from_jit_pytree(self): + with self.assertRaisesRegex( + ValueError, + r"tree path \['hi'\]"): + jax.jit(lambda x_ref: {'hi': x_ref})(core.mutable_array(jnp.arange(3))) + + def test_return_from_jit_closure(self): + with self.assertRaisesRegex( + ValueError, + r"tree path \['hi'\]"): + x_ref = core.mutable_array(jnp.arange(3)) + jax.jit(lambda: {'hi': x_ref})() + + def test_argument_aliases_jit(self): + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( + ValueError, "appeared at both x_ref and y_ref"): + jax.jit(lambda x_ref, y_ref: x_ref[...] + y_ref[...])(x_ref, x_ref) + + def test_closure_and_argument_aliases_jit(self): + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( + ValueError, "closed over and passed as the argument y_ref"): + jax.jit(lambda y_ref: x_ref[...] + y_ref[...])(x_ref) + + def test_return_from_scan(self): + with self.assertRaisesRegex( + ValueError, "traced for scan returned a mutable array reference of type"): + jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3)) + + # TODO test_argument_aliases_scan + # TODO test_closure_and_argument_aliases_scan + + def test_return_from_cond(self): + with self.assertRaisesRegex( + ValueError, "traced for cond returned a mutable array reference of type"): + jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0)) + + # TODO test_argument_aliases_cond + # TODO test_closure_and_argument_aliases_cond + + # TODO test_return_from_custom_jvp/vjp + # TODO test_argument_aliases_custom_jvp/vjp + # TODO test_closure_and_argument_aliases_custom_jvp/vjp + + # TODO(mattjj): enable when cond works with mutable arrays + # @parameterized.parameters([False, True]) + # def test_cond_both_branches_close_over_same_mutable_array(self, jit): + # # see also test_cond_with_ref_reuse in state_test.py + # x_ref = core.mutable_array(0.) + # def f(pred): + # def true_fun(): + # x_ref[()] = 1. + # def false_fun(): + # x_ref[()] = 2. + # jax.lax.cond(pred, true_fun, false_fun) + # if jit: + # f = jax.jit(f) + # out_true = f(True) + # self.assertAllClose(x_ref[...], 1.) + # out_false = f(False) + # self.assertAllClose(x_ref[...], 2.) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())