From 96769f96c217b39da1b0c376020f859f8a18d921 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 17 Dec 2024 19:54:42 -0500 Subject: [PATCH] Even more linearize fixes --- jax/_src/interpreters/ad.py | 61 +++++++++++++++++++++++++++++------ jax/_src/interpreters/pxla.py | 7 ++++ jax/experimental/attrs.py | 2 ++ jax/experimental/shard_map.py | 49 ++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 9 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 85ade5e021f0..b26d59db6bef 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -86,14 +86,15 @@ def jvpfun(f, instantiate, transform_stack, primals, tangents): def linearize_subtrace(_f, _store, _tag, nzs_in, *primals, **params): with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace() - tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) - for (p, nz) in zip(primals, nzs_in) if nz] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) - tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] + tracers = [LinearizeTracer(linearize_trace, p, + tangent_trace.new_arg(get_aval(p).to_tangent_aval())) + if nz else p + for p, nz in zip(primals, nzs_in)] with core.set_current_trace(linearize_trace): ans = _f(*tracers) out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans)) - nzs_out = [type(t) is not Zero for t in out_tangents] + nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = [t for t, nz in zip(out_tangents, nzs_out) if nz] out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) @@ -135,6 +136,10 @@ def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr: effects=jaxpr.effects, debug_info=dbg) def linearize_jaxpr(jaxpr, nonzeros): + return _linearize_jaxpr(jaxpr, tuple(nonzeros)) + +@weakref_lru_cache +def _linearize_jaxpr(jaxpr, nonzeros): primal_trace = pe.DynamicJaxprTrace() tangent_trace = pe.DynamicJaxprTrace() lin_trace = LinearizeTrace(primal_trace, tangent_trace) @@ -154,11 +159,13 @@ def new_arg(primal_aval, nz): out_tangents = [tangent_trace.to_jaxpr_tracer(t) for (nz, t) in zip(nzs_out, out_tangents) if nz] tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) + tangent_trace.invalidate() if attrs_tracked: raise NotImplementedError("TODO: attrs") residuals_and_primals = (*tangent_consts, *out_primals) residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals) + primal_trace.invalidate() num_residuals = len(tangent_consts) tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr)) if attrs_tracked: @@ -187,6 +194,7 @@ def direct_linearize(traceable, primals, kwargs, *, has_aux=False, tag=None): out_tangents = map(instantiate_zeros, out_tangents) out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) + tangent_trace.invalidate() out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents] if attrs_tracked: raise NotImplementedError("TODO: attrs") @@ -551,6 +559,7 @@ def _primal_tangent_shapes_match(primal, tangent): assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) call_param_updaters: dict[core.Primitive, Callable] = {} +call_linearize_param_updaters: dict[core.Primitive, Callable] = {} call_transpose_param_updaters: dict[core.Primitive, Callable] = {} # -------------------- Linearize trace -------------------- @@ -637,13 +646,42 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) - nzs_in = [type(t) is not Zero for t in tangents] + nzs_in = tuple(type(t) is not Zero for t in tangents) f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in) - all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), params) + if isinstance(call_primitive, core.MapPrimitive): + @as_hashable_function(closure=(linearize_outs_thunk)) + def new_out_axes_thunk(): + num_residuals, _, _ = linearize_outs_thunk() + out_axes = params['out_axes_thunk']() + return (*(0 for _ in range(num_residuals)), *out_axes) + primal_params = dict(params, out_axes_thunk=new_out_axes_thunk) + else: + primal_params = params + + all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params) num_residuals, nzs_out, lin_jaxpr = linearize_outs_thunk() residuals = all_primal_results[:num_residuals] primals_out = all_primal_results[num_residuals:] + if isinstance(call_primitive, core.MapPrimitive): + in_axes = params['in_axes'] + out_axes = params['out_axes_thunk']() + residual_avals = map(get_aval, residuals) + new_in_axes = (*(0 for _ in residual_avals), + *(ax for ax, nz in zip(in_axes, nzs_in) if nz)) + new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),) + # NOTE: This assumes that the output tangents being zero is a + # deterministic function of which input tangents were zero. + @as_hashable_function(closure=(new_out_axes)) + def new_out_axes_thunk(): + return new_out_axes + params = dict(params, + in_axes=new_in_axes, + out_axes_thunk=new_out_axes_thunk) + + update_params = call_linearize_param_updaters.get(call_primitive) + new_params = update_params(params, residual_avals, nzs_in) if update_params else params + def f_tangent(*args): residuals = args[:num_residuals] nz_tangents = args[num_residuals:] @@ -651,12 +689,17 @@ def f_tangent(*args): nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] nz_tangents_out = call_primitive.bind_with_trace( - self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), params) + self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), new_params) nz_tangents_out_iter = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal) for nz, primal in zip(nzs_out, primals_out)] return map(partial(maybe_linearize_tracer, self), primals_out, nzs_out, tangents_out) + # The only difference between process_map and process_call is that + # the `in_axes` and `out_axes_thunk` params must be updated; + # that's handled in process_call. + process_map = process_call + def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): if is_nonzero: assert not type(tangent) is Zero @@ -692,8 +735,8 @@ def make_zero(aval): else: zero_type = Zero - tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval) - for aval, nz in zip(tangent_avals, nonzeros)] + tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval) + for aval, nz in zip(tangent_avals, nonzeros)) with core.set_current_trace(trace): out_primals, out_tangents = jvp(primals, tangent_args, **params) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 150db8a44fb6..6366e997270f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1396,6 +1396,12 @@ def xla_call_jvp_update_params(params, nz_tangents): new_donated_invars = (*donated_invars, *donated_tangents) return dict(params, donated_invars=new_donated_invars) +def _xla_call_linearize_update_params(params, residual_avals, nz_tangents): + donated_invars_prev = params['donated_invars'] + donated_invars = (*(False for _ in residual_avals), + *(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz)) + return dict(params, donated_invars=donated_invars) + def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts): donated_invars = params['donated_invars'] donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u] @@ -1411,6 +1417,7 @@ def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts): res_aval=_pmap_partial_eval_custom_res_maker) pe.dce_rules[xla_pmap_p] = _pmap_dce_rule ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params +ad.call_linearize_param_updaters[xla_pmap_p] = _xla_call_linearize_update_params ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 59531ae561cb..1706eaa63312 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -148,6 +148,8 @@ def _getattr_jvp(trace, obj, attr): return getattr(obj, attr) ad.JVPTrace.process_getattr = _getattr_jvp +ad.LinearizeTrace.process_setattr = _setattr_jvp +ad.LinearizeTrace.process_getattr = _getattr_jvp def linearize(f, *primals, attrs: list[tuple[Any, str]] = []): attr_primals = [jax_getattr(o, a) for o, a in attrs] diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 05f44f3134e7..b6bd84e32ef0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1531,6 +1531,55 @@ def known_out_names(): return pe.merge_lists(out_knowns, out_tracers, out_consts) pe.JaxprTrace.process_shard_map = _shard_map_partial_eval +def _shard_map_linearize(trace, shard_map_p, f, tracers, mesh, in_names, + out_names_thunk, check_rep, rewrite, auto): + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) + nzs_in = tuple(type(t) is not ad.Zero for t in tangents) + f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in) + tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz] + all_names = _all_mesh_names_except_spmd(mesh, trace) + + @as_hashable_function(closure=(linearize_outs_thunk)) + def primal_out_names_thunk(): + num_residuals, _, _ = linearize_outs_thunk() + out_names = out_names_thunk() + return (*({0: all_names} for _ in range(num_residuals)), *out_names) + primal_params = dict( + mesh=mesh, in_names=in_names, + out_names_thunk=primal_out_names_thunk, check_rep=check_rep, + rewrite=rewrite, auto=auto) + all_primal_results = shard_map_p.bind_with_trace( + trace.parent_trace, (f_primal,) + tuple(primals), primal_params) + num_residuals, nzs_out, lin_jaxpr = linearize_outs_thunk() + residuals = all_primal_results[:num_residuals] + primals_out = all_primal_results[num_residuals:] + residual_avals = map(core.get_aval, residuals) + out_names = out_names_thunk() + new_in_names = (*({0: all_names} for _ in residual_avals), + *(ax for ax, nz in zip(in_names, nzs_in) if nz)) + new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),) + @as_hashable_function(closure=(new_out_names)) + def tangent_out_names_thunk(): + return new_out_names + tangent_params = dict( + mesh=mesh, in_names=new_in_names, + out_names_thunk=tangent_out_names_thunk, check_rep=check_rep, + rewrite=rewrite, auto=auto) + + def f_tangent(*args): + residuals = args[:num_residuals] + nz_tangents = args[num_residuals:] + return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents) + + nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] + nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace, + (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), tangent_params) + nz_tangents_out_iter = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) + for nz, primal in zip(nzs_out, primals_out)] + return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out) +ad.LinearizeTrace.process_shard_map = _shard_map_linearize + @lu.transformation2 def _promote_scalar_residuals(f, *args, **kwargs): jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)