Skip to content

Commit

Permalink
Merge pull request #25833 from jax-ml:more-linearize-fixes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715073762
  • Loading branch information
Google-ML-Automation committed Jan 13, 2025
2 parents e72c148 + 96769f9 commit dabe27b
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 9 deletions.
61 changes: 52 additions & 9 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,15 @@ def jvpfun(f, instantiate, transform_stack, primals, tangents):
def linearize_subtrace(_f, _store, _tag, nzs_in, *primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace()
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval())
for (p, nz) in zip(primals, nzs_in) if nz]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
if nz else p
for p, nz in zip(primals, nzs_in)]
with core.set_current_trace(linearize_trace):
ans = _f(*tracers)
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
nzs_out = [type(t) is not Zero for t in out_tangents]
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = [t for t, nz in zip(out_tangents, nzs_out) if nz]
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
Expand Down Expand Up @@ -135,6 +136,10 @@ def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
effects=jaxpr.effects, debug_info=dbg)

def linearize_jaxpr(jaxpr, nonzeros):
  """Linearize ``jaxpr``, delegating to the cached ``_linearize_jaxpr``.

  ``nonzeros`` may be any iterable; it is converted to a tuple before the
  call. ``_linearize_jaxpr`` is wrapped in ``weakref_lru_cache``, so the
  conversion presumably exists to give the cache a hashable key (confirm
  against the cache implementation).
  """
  nonzeros_key = tuple(nonzeros)
  return _linearize_jaxpr(jaxpr, nonzeros_key)

@weakref_lru_cache
def _linearize_jaxpr(jaxpr, nonzeros):
primal_trace = pe.DynamicJaxprTrace()
tangent_trace = pe.DynamicJaxprTrace()
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
Expand All @@ -154,11 +159,13 @@ def new_arg(primal_aval, nz):
out_tangents = [tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz]
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
tangent_trace.invalidate()
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
primal_trace.invalidate()
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
if attrs_tracked:
Expand Down Expand Up @@ -187,6 +194,7 @@ def direct_linearize(traceable, primals, kwargs, *, has_aux=False, tag=None):
out_tangents = map(instantiate_zeros, out_tangents)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
tangent_trace.invalidate()
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
Expand Down Expand Up @@ -551,6 +559,7 @@ def _primal_tangent_shapes_match(primal, tangent):
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)

# Registries of per-primitive hooks that rewrite a call primitive's params
# when the call is transformed. Each maps a primitive to a callable that
# returns the updated params; other modules populate them at import time.
#
# Hook used when a call is JVP-transformed.
call_param_updaters: dict[core.Primitive, Callable] = {}
# Hook used when a call is linearized. It is looked up in `process_call` of
# the linearize trace and invoked as
# `update_params(params, residual_avals, nzs_in)` to build the params for the
# tangent-side bind.
call_linearize_param_updaters: dict[core.Primitive, Callable] = {}
# Hook used when a call is transposed.
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}

# -------------------- Linearize trace --------------------
Expand Down Expand Up @@ -637,26 +646,60 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
def process_call(self, call_primitive, f, tracers, params):
  """Linearize a call (or map) primitive by splitting it into two binds.

  The wrapped function is bound once on ``self.parent_trace`` to produce the
  residuals followed by the primal outputs, and once on ``self.tangent_trace``
  to evaluate the linear jaxpr on those residuals plus the nonzero input
  tangents.

  Args:
    call_primitive: the call primitive being processed; it must have
      ``multiple_results`` set.
    f: the wrapped function handed to the primitive.
    tracers: the inputs; each is split with ``self.to_primal_tangent_pair``.
    params: the primitive's params. For a ``core.MapPrimitive`` the
      ``in_axes`` and ``out_axes_thunk`` entries are rewritten for each bind.
      A hook registered in ``call_linearize_param_updaters`` may further
      rewrite the params used for the tangent bind.

  Returns:
    One value per primal output, produced by ``maybe_linearize_tracer`` from
    the primal output, its nonzero flag and its tangent (a ``Zero`` when the
    flag is false).
  """
  assert call_primitive.multiple_results
  is_map = isinstance(call_primitive, core.MapPrimitive)
  primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
  nzs_in = tuple(type(t) is not Zero for t in tangents)
  f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in)

  if is_map:
    # Capture the caller's thunk eagerly so the closure below never depends
    # on a later rebinding of a local name.
    out_axes_thunk = params['out_axes_thunk']

    @as_hashable_function(closure=(linearize_outs_thunk))
    def primal_out_axes_thunk():
      # The primal call returns the residuals first; each gets out axis 0,
      # followed by the caller's original out axes.
      num_residuals, _, _ = linearize_outs_thunk()
      return (*(0 for _ in range(num_residuals)), *out_axes_thunk())
    primal_params = dict(params, out_axes_thunk=primal_out_axes_thunk)
  else:
    primal_params = params

  all_primal_results = call_primitive.bind_with_trace(
      self.parent_trace, (f_primal, *primals), primal_params)
  num_residuals, nzs_out, lin_jaxpr = linearize_outs_thunk()
  residuals = all_primal_results[:num_residuals]
  primals_out = all_primal_results[num_residuals:]
  # Bound unconditionally: the param-updater hook below receives it for every
  # kind of call primitive, not only for map primitives.
  residual_avals = [get_aval(r) for r in residuals]

  if is_map:
    in_axes = params['in_axes']
    out_axes = out_axes_thunk()
    # Tangent-call inputs are (residuals..., nonzero tangents...).
    new_in_axes = (*(0 for _ in residual_avals),
                   *(ax for ax, nz in zip(in_axes, nzs_in) if nz))
    new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),)
    # NOTE: This assumes that the output tangents being zero is a
    # deterministic function of which input tangents were zero.
    @as_hashable_function(closure=(new_out_axes))
    def tangent_out_axes_thunk():
      return new_out_axes
    tangent_params = dict(params,
                          in_axes=new_in_axes,
                          out_axes_thunk=tangent_out_axes_thunk)
  else:
    tangent_params = params

  update_params = call_linearize_param_updaters.get(call_primitive)
  if update_params:
    tangent_params = update_params(tangent_params, residual_avals, nzs_in)

  def f_tangent(*args):
    # Split the flat argument list back into residuals and nonzero tangents.
    residuals = args[:num_residuals]
    nz_tangents = args[num_residuals:]
    return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)

  nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
  nz_tangents_out = call_primitive.bind_with_trace(
      self.tangent_trace,
      (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in),
      tangent_params)
  # Re-insert symbolic zeros at the positions whose tangent is known zero.
  nz_tangents_out_iter = iter(nz_tangents_out)
  tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
                  for nz, primal in zip(nzs_out, primals_out)]
  return map(partial(maybe_linearize_tracer, self), primals_out, nzs_out, tangents_out)

# The only difference between process_map and process_call is that
# the `in_axes` and `out_axes_thunk` params must be updated;
# that's handled in process_call.
process_map = process_call

def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
if is_nonzero:
assert not type(tangent) is Zero
Expand Down Expand Up @@ -692,8 +735,8 @@ def make_zero(aval):
else:
zero_type = Zero

tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
for aval, nz in zip(tangent_avals, nonzeros)]
tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
for aval, nz in zip(tangent_avals, nonzeros))
with core.set_current_trace(trace):
out_primals, out_tangents = jvp(primals, tangent_args, **params)

Expand Down
7 changes: 7 additions & 0 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,6 +1396,12 @@ def xla_call_jvp_update_params(params, nz_tangents):
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)

def _xla_call_linearize_update_params(params, residual_avals, nz_tangents):
donated_invars_prev = params['donated_invars']
donated_invars = (*(False for _ in residual_avals),
*(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
return dict(params, donated_invars=donated_invars)

def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
Expand All @@ -1411,6 +1417,7 @@ def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
res_aval=_pmap_partial_eval_custom_res_maker)
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params
ad.call_linearize_param_updaters[xla_pmap_p] = _xla_call_linearize_update_params
ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params

ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def _getattr_jvp(trace, obj, attr):
return getattr(obj, attr)
ad.JVPTrace.process_getattr = _getattr_jvp

ad.LinearizeTrace.process_setattr = _setattr_jvp
ad.LinearizeTrace.process_getattr = _getattr_jvp

def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
attr_primals = [jax_getattr(o, a) for o, a in attrs]
Expand Down
49 changes: 49 additions & 0 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,55 @@ def known_out_names():
return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval

def _shard_map_linearize(trace, shard_map_p, f, tracers, mesh, in_names,
                         out_names_thunk, check_rep, rewrite, auto):
  """Linearization rule for ``shard_map_p`` under a linearize trace.

  The wrapped function is bound once on ``trace.parent_trace`` to produce the
  residuals followed by the primal outputs, and once on
  ``trace.tangent_trace`` to evaluate the linear jaxpr on those residuals
  plus the nonzero input tangents.

  Residuals are given the names spec ``{0: all_names}`` on both sides, where
  ``all_names`` comes from ``_all_mesh_names_except_spmd(mesh, trace)``.
  Tangent inputs and outputs reuse the primal names at the positions whose
  tangent is nonzero.

  Returns:
    One value per primal output, produced by ``ad.maybe_linearize_tracer``
    from the primal output, its nonzero flag and its tangent (an ``ad.Zero``
    when the flag is false).
  """
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
  f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in)
  all_names = _all_mesh_names_except_spmd(mesh, trace)

  @as_hashable_function(closure=(linearize_outs_thunk))
  def primal_out_names_thunk():
    # The primal call returns the residuals first, then the original outputs.
    num_residuals, _, _ = linearize_outs_thunk()
    out_names = out_names_thunk()
    return (*({0: all_names} for _ in range(num_residuals)), *out_names)
  primal_params = dict(
      mesh=mesh, in_names=in_names,
      out_names_thunk=primal_out_names_thunk, check_rep=check_rep,
      rewrite=rewrite, auto=auto)
  all_primal_results = shard_map_p.bind_with_trace(
      trace.parent_trace, (f_primal,) + tuple(primals), primal_params)
  num_residuals, nzs_out, lin_jaxpr = linearize_outs_thunk()
  residuals = all_primal_results[:num_residuals]
  primals_out = all_primal_results[num_residuals:]
  residual_avals = map(core.get_aval, residuals)
  out_names = out_names_thunk()
  # Tangent-call inputs are (residuals..., nonzero tangents...).
  new_in_names = (*({0: all_names} for _ in residual_avals),
                  *(ax for ax, nz in zip(in_names, nzs_in) if nz))
  new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),)
  @as_hashable_function(closure=(new_out_names))
  def tangent_out_names_thunk():
    return new_out_names
  tangent_params = dict(
      mesh=mesh, in_names=new_in_names,
      out_names_thunk=tangent_out_names_thunk, check_rep=check_rep,
      rewrite=rewrite, auto=auto)

  def f_tangent(*args):
    # Split the flat argument list back into residuals and nonzero tangents.
    residuals = args[:num_residuals]
    nz_tangents = args[num_residuals:]
    return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)

  nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
  nz_tangents_out = shard_map_p.bind_with_trace(
      trace.tangent_trace,
      (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), tangent_params)
  # Re-insert symbolic zeros at the positions whose tangent is known zero.
  nz_tangents_out_iter = iter(nz_tangents_out)
  tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
                  for nz, primal in zip(nzs_out, primals_out)]
  return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out)
ad.LinearizeTrace.process_shard_map = _shard_map_linearize

@lu.transformation2
def _promote_scalar_residuals(f, *args, **kwargs):
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
Expand Down

0 comments on commit dabe27b

Please sign in to comment.