From 8fe8d241e87ba829db44377f5e3127fce244838b Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 11 Dec 2024 13:46:48 -0500 Subject: [PATCH] Fixes to direct linearize * Fix a bug in pjit linearization rule * Handle multiple results and zeros in fallback rule * Handle `has_aux` * Implement process_custom_vjp_call --- jax/_src/interpreters/ad.py | 112 ++++++++++++++++++++++++++++-------- 1 file changed, 87 insertions(+), 25 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 3f6c5ee5b043..48b2d81d09e1 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -143,21 +143,31 @@ def new_arg(primal_aval, nz): def direct_linearize(traceable, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) - assert not has_aux with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace() tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] linearize_trace = LinearizeTrace(parent_trace, tangent_trace) tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] with core.set_current_trace(linearize_trace): - ans = traceable.call_wrapped(*tracers) - + if has_aux: + ans, aux = traceable.call_wrapped(*tracers) + aux_primals = [x.primal + if isinstance(x, LinearizeTracer) + and x._trace.tag is linearize_trace.tag + else x for x in aux] + else: + ans = traceable.call_wrapped(*tracers) + aux = None out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans)) + out_tangents = map(instantiate_zeros, out_tangents) out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents] del attrs_tracked # TODO: attrs - return out_primals, out_tangents_pvals, jaxpr, consts + if has_aux: + return out_primals, out_tangents_pvals, jaxpr, consts, aux_primals + else: + return out_primals, out_tangents_pvals, jaxpr, consts def linearize(traceable, *primals, **kwargs): if config.use_direct_linearize.value: @@ -532,22 +542,45 @@ def to_primal_tangent_pair(self, val): def process_primitive(self, primitive, args, params): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) - tangent_nonzeros = [type(t) is not Zero for t in tangents_in] + tangent_nzs = [type(t) is not Zero for t in tangents_in] if all(type(t) is Zero for t in tangents_in): return primitive.bind_with_trace(self.parent_trace, primals_in, params) - lin = primitive_linearizations.get(primitive) - if lin is None: - lin = partial(fallback_linearize_rule, primitive) + fallback = partial(fallback_linearize_rule, primitive) + lin = primitive_linearizations.get(primitive, fallback) with core.set_current_trace(self.parent_trace): - primal_out, tangent_nonzeros_out, residuals, linearized = lin( - tangent_nonzeros, *primals_in, **params) + primal_out, tangent_nzs_out, residuals, linearized = lin( + tangent_nzs, *primals_in, **params) with core.set_current_trace(self.tangent_trace): tangent_out = linearized(residuals, *tangents_in) if primitive.multiple_results: return [maybe_linearize_tracer(self, x, nz, t) - for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)] + for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)] else: - return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out) + return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out) + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + symbolic_zeros): + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) + if all(type(t) is Zero for t in tangents_in): + return prim.bind_with_trace(self.parent_trace, + (fun, fwd, bwd, *primals_in), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) + fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)] + fwd_in = [x for pair in fwd_in for x in pair] # flatten + with core.set_current_trace(self.parent_trace): + res_and_primals_out = fwd.call_wrapped(*fwd_in) + + _, res_tree = out_trees() + res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] + + with core.set_current_trace(self.tangent_trace): + tangents_in = map(instantiate_zeros, tangents_in) + tangents_out = custom_lin_p.bind( + *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros) + tangent_nzs_out = [type(t) is not Zero for t in tangents_out] + return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out) def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): if is_nonzero: @@ -557,21 +590,50 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): assert type(tangent) is Zero return primal -def fallback_linearize_rule(prim, _, *args, **kwargs): - assert not prim.multiple_results - - def call_prim(*args_): - return [prim.bind(*args_, **kwargs)] - - with config.use_direct_linearize(False): - (out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize( - lu.wrap_init(call_prim), *args, **kwargs) +def fallback_linearize_rule(prim, nonzeros, *primals, **params): + jvp = primitive_jvps.get(prim) + if not jvp: + msg = f"Differentiation rule for '{prim}' not implemented" + raise NotImplementedError(msg) + current_name_stack = source_info_util.current_name_stack() + with core.take_current_trace() as parent_trace: + trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag()) + tangent_avals = [get_aval(p).to_tangent_aval() for p in primals] + tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else Zero(aval) + for aval, nz in zip(tangent_avals, nonzeros)] + with core.set_current_trace(trace): + out_primals, out_tangents = jvp(primals, tangent_args, **params) + + if not prim.multiple_results: + out_primals = [out_primals] + out_tangents = [out_tangents] + + out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals] + out_nzs = [type(r) is not Zero for r in out_tangents] + out_tangent_avals = [get_aval(p).to_tangent_aval() for p in out_primals] + out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] + in_tracers = [t for t in tangent_args if type(t) is not Zero] + jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers) + + def linearized(residuals, *tangents): + nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] + nz_tangents_out = core.eval_jaxpr(jaxpr, residuals, *nz_tangents_in) + nz_tangents_out_iter = iter(nz_tangents_out) + all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval) + for (aval, nz) in zip(out_tangent_avals, out_nzs)] + if prim.multiple_results: + return all_out_tangents + else: + out_tangent, = all_out_tangents + return out_tangent - def linearized(residuals, *tangents): - out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents) - return out_tangent + if prim.multiple_results: + return out_primals, out_nzs, out_consts, linearized + else: + out_primal, = out_primals + out_nz, = out_nzs + return out_primal, out_nz, out_consts, linearized - return out_primal, True, consts, linearized class LinearizeTracer(Tracer): __slots__ = ['primal', 'tangent']