Skip to content

Commit

Permalink
Merge pull request #25449 from mattjj:ref-errors-3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707449170
  • Loading branch information
Google-ML-Automation committed Dec 18, 2024
2 parents dc0b774 + 42ac4ca commit 96d4a75
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 77 deletions.
6 changes: 6 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,12 @@ def _update_disable_jit_thread_local(val):
upgrade=True,
help='Disable the check from #19009 to enable some custom_vjp hacks.')

mutable_array_checks = bool_state(
name='jax_mutable_array_checks',
default=False,
upgrade=True,
help='Enable error checks for mutable arrays that rule out aliasing.')

xla_runtime_errors = bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,6 +1917,8 @@ def mutable_array_abstract_eval(init_aval):
def _mutable_array_impl(init_val):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
aval = get_aval(init_val)
# TODO(mattjj): improve spelling of 'defensive copy' here, avoid circular dep
init_val = init_val.copy() if hasattr(init_val, 'copy') else init_val
return MutableArray(AbstractRef(aval), init_val)

def freeze(ref):
Expand Down
60 changes: 41 additions & 19 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,21 +986,11 @@ def partial_eval_jaxpr_custom(
ensure_out_inst: bool | Sequence[bool],
saveable: Callable[..., RematCases_],
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
if num_res_ref > 0:
raise ValueError(
"Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res
*outs, num_res_ref = partial_eval_jaxpr_stateful(
jaxpr, in_unknowns, in_inst, ensure_out_unknowns, ensure_out_inst, saveable)
if num_res_ref:
raise ValueError("Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
return *outs, # type: ignore

def partial_eval_jaxpr_stateful(
jaxpr: Jaxpr,
Expand All @@ -1019,10 +1009,9 @@ def partial_eval_jaxpr_stateful(
if saveable is None:
saveable = everything_saveable
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
_partial_eval_jaxpr_custom_cached(
jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref

everything_saveable = lambda *_, **__: True
Expand Down Expand Up @@ -2165,12 +2154,45 @@ def trace_to_jaxpr_dynamic(
ans = fun.call_wrapped(*in_tracers)

out_tracers = map(trace.to_jaxpr_tracer, ans)
_check_no_refs(debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
del trace, fun, in_tracers, out_tracers, ans

config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked

def _check_no_refs(
dbg: lu.TracingDebugInfo | None,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
if not config.mutable_array_checks.value: return
for i, t in enumerate(out_tracers):
a = t.aval
if isinstance(a, AbstractRef):
if dbg is None:
raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.")
loc = (f' at output tree path {keystr(ls[i])}' # type: ignore
if dbg.result_paths and (ls := dbg.result_paths()) and ls[i] else '')
frame = t._trace.frame
v = frame.tracer_to_var.get(id(t))
eqn = next((e for e in frame.eqns if v in e.outvars), None)
if eqn:
assert eqn.primitive is core.mutable_array_p
origin_info = ('\n\nThe returned mutable array was created on line '
f'{source_info_util.summarize(eqn.source_info)}.')
elif v in frame.invars:
arg_name = dbg.arg_names[frame.invars.index(v)]
origin_info = ('\n\nThe returned mutable array was passed in as the '
f'argument {arg_name}.')
else:
origin_info = ''
raise ValueError(
f"function {dbg.func_src_info} traced for {dbg.traced_for} returned "
f"a mutable array reference of type {a.str_short()}{loc}, but "
f"mutable array references cannot be returned.{origin_info}")

@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
Expand Down
117 changes: 66 additions & 51 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,17 +556,14 @@ def _infer_params_impl(
"pjit does not support kwargs when in_shardings is specified.")

if pjit_mesh is not None:
jit_name = 'pjit'
if (ji.backend or ji.device) and not pjit_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
else:
jit_name = 'jit'

axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)

dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
Expand All @@ -593,6 +590,7 @@ def _infer_params_impl(
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
in_shardings_treedef = out_shardings_treedef = treedef
else:
jit_name = 'pjit' if pjit_mesh is not None else 'jit'
in_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
for x in ji.in_shardings_leaves)
Expand All @@ -607,35 +605,12 @@ def _infer_params_impl(

in_type: core.InputType | tuple[core.AbstractValue, ...]
if config.dynamic_shapes.value:
assert in_avals is None
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e)
elif in_avals is None:
avals = []
for i, a in enumerate(explicit_args):
try:
avals.append(shaped_abstractify(a))
except OverflowError as e:
arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
else f"flattened argument number is {i}")
raise OverflowError(
"An overflow was encountered while parsing an argument to a jitted "
f"computation, whose {arg_path}."
) from e
except TypeError as e:
arg_description = (f"path {dbg.arg_names[i]}" if dbg
else f"flattened argument number {i}")
raise TypeError(
f"Error interpreting argument to {fun} as an abstract array."
f" The problematic value is of type {type(a)} and was passed to"
f" the function at {arg_description}.\n"
"This typically means that a jit-wrapped function was called with a non-array"
" argument, and this argument was not marked as static using the"
" static_argnums or static_argnames parameters of jax.jit."
) from e

in_type = in_avals = tuple(avals)
else:
in_type = in_avals
in_type = in_avals # type: ignore
assert in_avals is not None

in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
in_shardings_treedef, in_shardings_leaves,
Expand All @@ -652,6 +627,7 @@ def _infer_params_impl(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
IgnoreKey(ji.inline))
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)

out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
Expand Down Expand Up @@ -693,7 +669,6 @@ def _infer_params_impl(
donated_invars, dbg.arg_names if dbg else None, len(consts),
attrs_tracked, abstract_mesh), args_flat


def get_abstract_mesh_from_avals(in_avals):
if not config.sharding_in_types.value:
return None
Expand All @@ -711,9 +686,7 @@ def get_abstract_mesh_from_avals(in_avals):
class InferParamsCacheEntry:
"""Mutable value object for _infer_params_cached."""
__slots__ = ['pjit_params']

pjit_params: PjitParams | None

def __init__(self):
self.pjit_params = None

Expand Down Expand Up @@ -747,34 +720,76 @@ def _infer_params(
resource_env = None
pjit_mesh = None

skip_cache = config.dynamic_shapes.value
if not skip_cache:
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
try:
avals = tuple(shaped_abstractify(a) for a in dynargs)
except (OverflowError, TypeError):
# If we see something we don't understand, use the slow path.
skip_cache = True

if skip_cache:
if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
kwargs, in_avals=None)
return p, p.consts + args_flat

entry = _infer_params_cached(
fun, ji, signature, avals, pjit_mesh, resource_env)
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
avals = _infer_input_type(fun, dbg, dynargs)
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
if entry.pjit_params is None:
p, args_flat = _infer_params_impl(
fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
if p.attrs_tracked:
# If there are attrs_tracked, don't use the cache.
if p.attrs_tracked: # if attrs, don't popoulate the cache
return p, p.consts + args_flat
else:
entry.pjit_params = p
entry.pjit_params = p
return entry.pjit_params, entry.pjit_params.consts + dynargs

def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]:
avals = []
try:
for i, x in enumerate(explicit_args):
avals.append(shaped_abstractify(x))
except OverflowError:
arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg # type: ignore
else f"flattened argument number is {i}") # type: ignore
raise OverflowError(
"An overflow was encountered while parsing an argument to a jitted "
f"computation, whose {arg_path}."
) from None
except TypeError:
arg_description = (f"path {dbg.arg_names[i]}" if dbg # type: ignore
else f"flattened argument number {i}") # type: ignore
raise TypeError(
f"Error interpreting argument to {fun} as an abstract array."
f" The problematic value is of type {type(x)} and was passed to" # type: ignore
f" the function at {arg_description}.\n"
"This typically means that a jit-wrapped function was called with a non-array"
" argument, and this argument was not marked as static using the"
" static_argnums or static_argnames parameters of jax.jit."
) from None
if config.mutable_array_checks.value:
# TODO(mattjj): make this faster
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, explicit_args)):
if (isinstance(a, AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
f"the mutable array reference of type {a.str_short()} appeared at both "
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None
return tuple(avals)

def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
if not config.mutable_array_checks.value: return
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
for i, x in enumerate(args):
if id(core.get_referent(x)) in refs:
a = shaped_abstractify(x)
raise ValueError(
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")

def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
Expand Down
1 change: 1 addition & 0 deletions tests/attrs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def body(x, __):
self.assertAllClose(thing.x, 1024., check_dtypes=False)

def test_arg_to_jit(self):
self.skipTest("regressed this experimental feature") # TODO(mattjj)
thing = Thing(1.0)
count = 0

Expand Down
Loading

0 comments on commit 96d4a75

Please sign in to comment.