Skip to content

Commit

Permalink
Fixes to direct linearize
Browse files Browse the repository at this point in the history
  * Fix a bug in pjit linearization rule
  * Handle multiple results and zeros in fallback rule
  * Handle `has_aux`
  * Implement process_custom_vjp_call
  • Loading branch information
dougalm committed Dec 12, 2024
1 parent 20236f1 commit 8fe8d24
Showing 1 changed file with 87 additions and 25 deletions.
112 changes: 87 additions & 25 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,31 @@ def new_arg(primal_aval, nz):

def direct_linearize(traceable, *primals, **kwargs):
  """Linearize `traceable` at `primals` in a single LinearizeTrace pass.

  Args:
    traceable: wrapped function; invoked through `traceable.call_wrapped`.
    *primals: primal input values at which to linearize.
    **kwargs: only `has_aux` (default False) is read here. When true,
      `traceable` is expected to return an `(outputs, aux)` pair.

  Returns:
    `(out_primals, out_tangents_pvals, jaxpr, consts)`, followed by
    `aux_primals` when `has_aux` is true. `jaxpr`/`consts` come from
    `tangent_trace.to_jaxpr` and describe the tangent computation.
  """
  has_aux = kwargs.pop('has_aux', False)
  with core.take_current_trace() as parent_trace:
    tangent_trace = pe.DynamicJaxprTrace()
    # One tangent-trace argument per primal input.
    tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval())
                for p in primals]
    linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
    tracers = [LinearizeTracer(linearize_trace, p, t)
               for p, t in zip(primals, tangents)]
    with core.set_current_trace(linearize_trace):
      if has_aux:
        ans, aux = traceable.call_wrapped(*tracers)
        # Strip tracers belonging to this linearization from the aux outputs
        # so that only their primal values are returned to the caller.
        aux_primals = [x.primal
                       if isinstance(x, LinearizeTracer)
                       and x._trace.tag is linearize_trace.tag
                       else x for x in aux]
      else:
        ans = traceable.call_wrapped(*tracers)

    # NOTE(review): the scraped diff lost indentation; these steps are placed
    # after the inner `with` -- confirm against the upstream file.
    out_primals, out_tangents = unzip2(
        map(linearize_trace.to_primal_tangent_pair, ans))
    out_tangents = map(instantiate_zeros, out_tangents)
    out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
    jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
    out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t))
                          for t in out_tangents]
    del attrs_tracked  # TODO: attrs

  if has_aux:
    return out_primals, out_tangents_pvals, jaxpr, consts, aux_primals
  else:
    return out_primals, out_tangents_pvals, jaxpr, consts

def linearize(traceable, *primals, **kwargs):
if config.use_direct_linearize.value:
Expand Down Expand Up @@ -532,22 +542,45 @@ def to_primal_tangent_pair(self, val):

def process_primitive(self, primitive, args, params):
  """Apply `primitive` under this trace, linearizing it when needed.

  If every input tangent is a symbolic `Zero`, the primitive is simply bound
  in the parent trace on the primal values. Otherwise the registered rule in
  `primitive_linearizations` (or `fallback_linearize_rule` when none is
  registered) is run in the parent trace, and the linear function it returns
  is run in the tangent trace.

  Returns:
    One result (or a list of results when `primitive.multiple_results`), each
    produced by `maybe_linearize_tracer` from the primal output, the rule's
    output non-zero flag and the tangent output.
  """
  primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
  tangent_nzs = [type(t) is not Zero for t in tangents_in]
  if not any(tangent_nzs):
    # Nothing to differentiate: evaluate on primals only.
    return primitive.bind_with_trace(self.parent_trace, primals_in, params)
  fallback = partial(fallback_linearize_rule, primitive)
  lin = primitive_linearizations.get(primitive, fallback)
  with core.set_current_trace(self.parent_trace):
    primal_out, tangent_nzs_out, residuals, linearized = lin(
        tangent_nzs, *primals_in, **params)
  with core.set_current_trace(self.tangent_trace):
    tangent_out = linearized(residuals, *tangents_in)
  # Use the rule's *output* non-zero flags (not the input mask) to decide
  # which outputs carry a tangent.
  if primitive.multiple_results:
    return [maybe_linearize_tracer(self, x, nz, t)
            for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)]
  else:
    return maybe_linearize_tracer(self, primal_out, tangent_nzs_out,
                                  tangent_out)

def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
                            symbolic_zeros):
  """Handle a custom-VJP call under this trace.

  If every input tangent is a symbolic `Zero`, `prim` is re-bound in the
  parent trace on the primal values. Otherwise `fwd` is run in the parent
  trace to obtain residuals and primal outputs, and `custom_lin_p` is bound in
  the tangent trace on the residuals and the (instantiated) input tangents.

  Returns:
    The results of `maybe_linearize_tracer` applied to each
    (primal output, output-tangent-is-non-zero flag, output tangent) triple.
  """
  primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
  if all(type(t) is Zero for t in tangents_in):
    return prim.bind_with_trace(
        self.parent_trace,
        (fun, fwd, bwd, *primals_in),
        dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
  # `fwd` receives each primal followed by a flag saying whether its tangent
  # is non-zero, flattened into one argument list.
  fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
  fwd_in = [x for pair in fwd_in for x in pair]  # flatten
  with core.set_current_trace(self.parent_trace):
    res_and_primals_out = fwd.call_wrapped(*fwd_in)

  # The first `res_tree.num_leaves` outputs of `fwd` are the residuals.
  _, res_tree = out_trees()
  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
  avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]

  with core.set_current_trace(self.tangent_trace):
    tangents_in = map(instantiate_zeros, tangents_in)
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
        out_avals=avals_out, symbolic_zeros=symbolic_zeros)
  tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
  return map(partial(maybe_linearize_tracer, self), primals_out,
             tangent_nzs_out, tangents_out)

def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
if is_nonzero:
Expand All @@ -557,21 +590,50 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
assert type(tangent) is Zero
return primal

def fallback_linearize_rule(prim, nonzeros, *primals, **params):
  """Linearization rule derived from a primitive's JVP rule.

  Runs the JVP rule from `primitive_jvps` under a `pe.JaxprTrace` in which the
  non-zero input tangents are unknown, then turns the traced tangent
  computation into a jaxpr that the returned `linearized` function evaluates.

  Args:
    prim: the primitive to linearize.
    nonzeros: one flag per primal; false means that input tangent is a
      symbolic `Zero`.
    *primals: primal input values.
    **params: primitive parameters, forwarded to the JVP rule.

  Returns:
    `(out_primals, out_nzs, residuals, linearized)`. For a primitive without
    `multiple_results`, `out_primals` and `out_nzs` are single values rather
    than lists. `linearized(residuals, *tangents)` takes one tangent per
    primal input and returns the output tangent(s), with `Zero` for outputs
    whose tangent is symbolically zero.

  Raises:
    NotImplementedError: if `prim` has no entry in `primitive_jvps`.
  """
  jvp = primitive_jvps.get(prim)
  if not jvp:
    msg = f"Differentiation rule for '{prim}' not implemented"
    raise NotImplementedError(msg)
  current_name_stack = source_info_util.current_name_stack()
  with core.take_current_trace() as parent_trace:
    trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
    tangent_avals = [get_aval(p).to_tangent_aval() for p in primals]
    tangent_args = [
        trace.new_arg(pe.PartialVal.unknown(aval)) if nz else Zero(aval)
        for aval, nz in zip(tangent_avals, nonzeros)]
    with core.set_current_trace(trace):
      out_primals, out_tangents = jvp(primals, tangent_args, **params)

    # Normalize to lists so single- and multi-result primitives share a path.
    if not prim.multiple_results:
      out_primals = [out_primals]
      out_tangents = [out_tangents]

    out_primals = [trace.to_jaxpr_tracer(p).pval.get_known()
                   for p in out_primals]
    out_nzs = [type(r) is not Zero for r in out_tangents]
    out_tangent_avals = [get_aval(p).to_tangent_aval() for p in out_primals]
    # Only non-zero tangents take part in the jaxpr, on both input and output.
    out_nz_tracers = [trace.to_jaxpr_tracer(r)
                      for (r, nz) in zip(out_tangents, out_nzs) if nz]
    in_tracers = [t for t in tangent_args if type(t) is not Zero]
    jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)

  def linearized(residuals, *tangents):
    nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
    nz_tangents_out = core.eval_jaxpr(jaxpr, residuals, *nz_tangents_in)
    # Re-insert symbolic zeros at the positions the jaxpr does not compute.
    nz_tangents_out_iter = iter(nz_tangents_out)
    all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval)
                        for (aval, nz) in zip(out_tangent_avals, out_nzs)]
    if prim.multiple_results:
      return all_out_tangents
    else:
      out_tangent, = all_out_tangents
      return out_tangent

  if prim.multiple_results:
    return out_primals, out_nzs, out_consts, linearized
  else:
    out_primal, = out_primals
    out_nz, = out_nzs
    return out_primal, out_nz, out_consts, linearized

class LinearizeTracer(Tracer):
__slots__ = ['primal', 'tangent']
Expand Down

0 comments on commit 8fe8d24

Please sign in to comment.