Initial AD impl
sharadmv committed May 8, 2023
1 parent 2d9aa38 commit 02184d7
Showing 2 changed files with 325 additions and 12 deletions.
319 changes: 316 additions & 3 deletions jax_triton/pallas/pallas_call.py
@@ -15,6 +15,7 @@
"""Module for calling pallas functions from JAX."""
from functools import partial
import itertools as it
import operator as op

from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple, Union

@@ -30,11 +31,12 @@
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src.lib.mlir.dialects import mhlo
from jax._src import source_info_util
from jax._src import state
from jax._src.state import discharge as state_discharge
from jax._src.util import (
    split_list, safe_map, safe_zip, weakref_lru_cache,
-    tuple_insert, partition_list)
+    tuple_insert, partition_list, merge_lists)
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
import numpy as np
@@ -141,6 +143,8 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_):
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
                          input_output_aliases: Tuple[Tuple[int, int], ...],
                          in_shapes, out_shapes, grid_spec, debug, interpret,
                          **compiler_params: Any):
  num_inputs = len(in_shapes)
  num_outputs = len(out_shapes)
  if input_output_aliases:
    raise NotImplementedError("JVP with aliasing not supported.")
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
@@ -151,7 +155,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
  closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
  jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
-  jvp_which_linear = which_linear + (True,) * len(tangents)
+  jvp_which_linear = (*which_linear, *(True,) * len(tangents))
  jvp_inshapes = (*in_shapes, *in_shapes)
  jvp_outshapes = (*out_shapes, *out_shapes)
  if input_output_aliases:
@@ -190,6 +194,316 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
  return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule

_save_everything = lambda *_, **__: True
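# (`_save_everything` is the `saveable` policy passed to
# `pe.partial_eval_jaxpr_custom` throughout this file: every intermediate is
# treated as worth saving, so all of them become residuals.)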

def _convert_outputs_to_writes(
    jaxpr: jax_core.Jaxpr,
) -> Tuple[jax_core.Jaxpr, list[jax_core.ShapedArray]]:
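  # Rewrites a jaxpr of type `[*refs] -> [*residuals]` into one of type
  # `[*refs, *residual_refs] -> []` by writing each residual value into a
  # fresh `Ref`. Schematically (the names below are illustrative, not from a
  # real trace):
  #   { lambda ; a:Ref{f32[]}. let x <- a; y = sin x in (y,) }
  # becomes
  #   { lambda ; a:Ref{f32[]} r:Ref{f32[]}. let x <- a; y = sin x;
  #     r[()] <- y in () }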
  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."

  in_avals = [v.aval for v in jaxpr.invars]  # [*orig_ref_avals]
  @lu.wrap_init
  def eval_jaxpr(*refs):
    # We split the refs into the original input refs and the dummy residual
    # refs.
    orig_refs, residual_refs = split_list(refs, [len(in_avals)])
    residual_vals = jax_core.eval_jaxpr(jaxpr, (), *orig_refs)
    for res_ref, res_val in zip(residual_refs, residual_vals):
      res_ref[()] = res_val
    return []
  res_ref_avals = [state.ShapedArrayRef(v.aval.shape, v.aval.dtype)  # pytype: disable=attribute-error
                   for v in jaxpr.outvars]
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [*in_avals, *res_ref_avals])
  assert not consts
  return jaxpr, [jax_core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]

def _convert_inputs_to_reads(num_res: int, jaxpr: jax_core.Jaxpr
                             ) -> jax_core.Jaxpr:
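  # The dual rewrite: a jaxpr of type `[*res_vals, *refs] -> []` becomes one
  # of type `[*res_refs, *refs] -> []` that begins by reading each residual
  # value out of its `Ref`.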
  assert not jaxpr.constvars, "Jaxpr should not have constvars"

  @lu.wrap_init
  def eval_jaxpr(*refs):
    residual_refs, orig_refs = split_list(refs, [num_res])
    residual_vals = [r[()] for r in residual_refs]
    () = jax_core.eval_jaxpr(jaxpr, (), *residual_vals, *orig_refs)
    return []

  res_val_avals, orig_ref_avals = split_list(
      [v.aval for v in jaxpr.invars], [num_res])
  res_ref_avals = [state.ShapedArrayRef(aval.shape, aval.dtype)
                   for aval in res_val_avals]

  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
  return jaxpr

def _pallas_call_partial_eval(
    trace: pe.JaxprTrace,
    *tracers: pe.JaxprTracer,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    grid_spec: pallas_core.GridSpec,
    which_linear: tuple[bool, ...],
    interpret: bool,
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    **compiler_params: Any):
  if input_output_aliases:
    raise NotImplementedError
  num_inputs = len(in_shapes)
  num_outputs = len(out_shapes)
  assert num_inputs + num_outputs == len(jaxpr.invars)
  in_unknowns = [not t.pval.is_known() for t in tracers]
  out_unknowns = [False] * num_outputs
  # We first need to run a fixpoint to determine which of the `Ref`s are
  # unknown after running the kernel. We want to use the jaxpr to determine
  # which `Ref`s are unknown after executing the body given which `Ref`s are
  # unknown before. However, the jaxpr has no outputs. Instead, we discharge
  # the body and run the fixpoint with the discharged jaxpr. We can do this
  # because the outputs of the discharged jaxpr are one-to-one with the
  # inputs.
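  # As a schematic example: if `x_ref` starts out unknown and the kernel does
  #   o_ref[...] = x_ref[...] + 1.0
  # then the discharged output corresponding to `o_ref` depends on the value
  # of `x_ref`, so `o_ref` is also marked unknown on the next sweep.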
  all_in_unknowns = [*in_unknowns, *out_unknowns]
  discharged_jaxpr, discharged_consts = state.discharge_state(jaxpr, ())
  discharged_jaxpr = discharged_jaxpr.replace(
      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
      constvars=[])
  for _ in range(num_inputs + num_outputs):
    jaxpr_in_unknowns = [False] * len(discharged_consts) + all_in_unknowns
    _, _, all_out_unknowns, _, _ = pe.partial_eval_jaxpr_custom(
        discharged_jaxpr, jaxpr_in_unknowns, [True] * len(jaxpr_in_unknowns),
        all_in_unknowns, False, _save_everything)
    all_out_unknowns = list(all_out_unknowns)
    if all_out_unknowns == all_in_unknowns:
      break
    all_in_unknowns = map(op.or_, all_in_unknowns, all_out_unknowns)
  else:
    raise Exception("Invalid fixpoint")
  all_unknowns = all_in_unknowns
  del all_in_unknowns, all_out_unknowns  # Use `all_unknowns` from here on.
  in_unknowns, out_unknowns = split_list(all_unknowns, [num_inputs])

  tracers = tuple(trace.instantiate_const(t) if uk else t  # type: ignore
                  for t, uk in zip(tracers, in_unknowns))

  # We use `partial_eval_jaxpr_custom` here because it won't remove effectful
  # primitives like `get`/`set`.
  jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
      pe.partial_eval_jaxpr_custom(
          jaxpr,
          in_inst=all_unknowns,
          in_unknowns=all_unknowns,
          ensure_out_unknowns=[],
          ensure_out_inst=[],
          saveable=_save_everything)
  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref`
  # and regular-valued inputs/outputs. However, we'd like to bind these
  # jaxprs to a `pallas_call`, which (like `for`) expects only `Ref` inputs
  # and no outputs. We need to convert both of these jaxprs into ones that
  # are compatible with that calling convention.
  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
  # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of
  # passing the loop index as a residual.

  # `jaxpr_known_resout` is a jaxpr that maps from all of the input `Ref`s to
  # output residual values (none of them should be `Ref`s). We convert the
  # output residual values into writes to initially-empty residual `Ref`s,
  # performed at the end of the jaxpr.
  jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
  jaxpr_unknown = _convert_inputs_to_reads(num_res, jaxpr_unknown_resin_)
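  # After these conversions, both jaxprs have `Ref`-only inputs and no
  # outputs:
  #   jaxpr_known:   [*known_in_refs, *known_out_refs, *res_refs] -> []
  #   jaxpr_unknown: [*res_refs, *unknown_in_refs, *unknown_out_refs] -> []
  # matching the operand orders of the two grid specs constructed below.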

  # Now we execute the forward pass that returns known outputs and residuals.
  grid, block_mappings, mapped_dims = (
      grid_spec.grid, grid_spec.block_mappings, grid_spec.mapped_dims)
  in_block_mappings, out_block_mappings = split_list(block_mappings,
                                                     [num_inputs])
  known_in_block_mappings, unknown_in_block_mappings = partition_list(
      in_unknowns, in_block_mappings)
  known_out_block_mappings, unknown_out_block_mappings = partition_list(
      out_unknowns, out_block_mappings)
  known_in_shapes, unknown_in_shapes = partition_list(in_unknowns, in_shapes)
  known_out_shapes, unknown_out_shapes = partition_list(out_unknowns,
                                                        out_shapes)
  known_which_linear, unknown_which_linear = partition_list(in_unknowns,
                                                            which_linear)
  res_which_linear = (False,) * num_res
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  res_shapes = [jax.ShapeDtypeStruct((*grid, *a.shape), a.dtype)
                for a in res_avals]
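  # Residuals are materialized at full block size with one leading axis per
  # grid dimension, so each grid point gets its own residual slot: the index
  # maps below pass the grid indices straight through and index the trailing
  # block-shaped dimensions at 0.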
  res_index_mappings = [
      jax_core.ClosedJaxpr(
          pe.trace_to_jaxpr_dynamic(
              lu.wrap_init(lambda *args: (*args, *[0] * len(a.shape))),
              [jax_core.ShapedArray((), jnp.int32)] * len(grid))[0], ())
      for a in res_avals
  ]
  res_block_mappings = [
      BlockMapping((*[None] * len(grid), *a.shape), index_map)
      for a, index_map in zip(res_avals, res_index_mappings)
  ]
  known_grid_spec = GridSpec(grid, (*known_in_block_mappings,
                                    *known_out_block_mappings,
                                    *res_block_mappings),
                             grid_spec.mapped_dims)
  unknown_grid_spec = GridSpec(grid, (*res_block_mappings,
                                      *unknown_in_block_mappings,
                                      *unknown_out_block_mappings),
                               grid_spec.mapped_dims)
  known_out_and_res = pallas_call_p.bind(
      *known_vals,
      jaxpr=jaxpr_known,
      grid_spec=known_grid_spec,
      in_shapes=tuple(known_in_shapes),
      out_shapes=(*known_out_shapes, *res_shapes),
      interpret=interpret,
      debug=debug,
      name=f"{name}_known",
      input_output_aliases=(),
      which_linear=tuple(known_which_linear),
      **compiler_params)
  known_outputs, residuals = split_list(known_out_and_res,
                                        [len(known_tracers)])
  residuals = map(trace.new_instantiated_const, residuals)
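  # The unknown `pallas_call` can't be bound directly since its inputs are
  # unknown tracers, so we stage it out by hand: make unknown output tracers
  # and attach an equation recipe that becomes the `pallas_call` equation
  # once the trace is finalized.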
  unknown_inputs = [*residuals, *unknown_tracers]
  unknown_outputs = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(
          jax_core.ShapedArray(s.shape, s.dtype)), None)
      for s in unknown_out_shapes]
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  unknown_params = dict(
      jaxpr=jaxpr_unknown,
      in_shapes=(*(jax.ShapeDtypeStruct(s.shape, s.dtype) for s in res_avals),
                 *unknown_in_shapes),
      out_shapes=tuple(unknown_out_shapes),
      grid_spec=unknown_grid_spec,
      which_linear=(*res_which_linear, *unknown_which_linear),
      debug=debug,
      interpret=interpret,
      name=f"{name}_unknown",
      input_output_aliases=(),
      **compiler_params)
  eqn = pe.new_eqn_recipe(unknown_inputs, unknown_outputs,
                          pallas_call_p, unknown_params,
                          jax_core.no_effects, source)
  for t in unknown_outputs: t.recipe = eqn
  return merge_lists(out_unknowns, known_outputs, unknown_outputs)
pe.custom_partial_eval_rules[pallas_call_p] = _pallas_call_partial_eval

def _transpose_jaxpr(jaxpr: jax_core.Jaxpr, which_linear: Sequence[bool]
                     ) -> jax_core.Jaxpr:
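  # Transposition follows the same known/unknown split used above: partially
  # evaluate with the linear (cotangent-carrying) inputs unknown, evaluate
  # the known half to compute residuals, then run `ad.backward_pass` on the
  # tangent half to propagate cotangents into the linear refs. (The
  # cotangent-argument selection here mirrors the inline `trans` in
  # `_pallas_call_transpose_rule` below.)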
  num_inputs = len(which_linear)
  num_outputs = len(jaxpr.invars) - num_inputs
  def trans(*args):
    # First we want to run the computation to read all the residual refs. We
    # can do that by using partial evaluation with all linear inputs unknown.
    res_jaxpr, tangent_jaxpr_, *_ = \
        pe.partial_eval_jaxpr_custom(
            jaxpr,
            in_unknowns=[*which_linear, *[True] * num_outputs],
            in_inst=[*which_linear, *[True] * num_outputs],
            ensure_out_inst=[],
            ensure_out_unknowns=[],
            saveable=_save_everything)
    res_args = [x for x, lin in zip(args, which_linear) if not lin]
    res = jax_core.eval_jaxpr(res_jaxpr, (), *res_args)

    # Now that we have residual values, we run the tangent jaxpr. It takes as
    # input the residuals and all of the refs (at least, the ones that are
    # used in the body). Luckily, `tangent_jaxpr_` has all known and unknown
    # inputs!
    ct_args = [x for x, lin in zip(args, [*which_linear,
                                          *[True] * num_outputs]) if lin]
    ad.backward_pass(
        tangent_jaxpr_, (), False, (), (*res, *ct_args), ())
    return []
  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
  return jaxpr_trans

def _pallas_call_transpose_rule(cts_in, *args,
                                jaxpr: jax_core.Jaxpr,
                                name: str,
                                in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                                out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                                grid_spec: GridSpec,
                                input_output_aliases: Tuple[Tuple[int, int], ...],
                                debug: bool,
                                interpret: bool,
                                which_linear: Tuple[bool, ...],
                                **compiler_params: Any):
  num_inputs = len(in_shapes)
  num_outputs = len(out_shapes)
  is_undefined_primal = [ad.is_undefined_primal(x) for x in args]
  defined_primals, undefined_primals = partition_list(is_undefined_primal,
                                                      args)
  defined_in_shapes, undefined_in_shapes = partition_list(is_undefined_primal,
                                                          in_shapes)
  block_mappings = grid_spec.block_mappings
  in_block_mappings, out_block_mappings = split_list(block_mappings,
                                                     [num_inputs])
  defined_in_block_mappings, undefined_in_block_mappings = partition_list(
      is_undefined_primal, in_block_mappings)
  defined_which_linear, undefined_which_linear = partition_list(
      is_undefined_primal, which_linear)
  num_undefined_inputs = sum(is_undefined_primal)
  num_defined_inputs = num_inputs - num_undefined_inputs
  def trans(*args):
    defined_primals, cts, undefined_primals = split_list(
        args, [num_defined_inputs, num_outputs])
    # First we want to run the computation to read all the residual refs. We
    # can do that by using partial evaluation with all linear inputs unknown.
    res_jaxpr, tangent_jaxpr_, *_ = \
        pe.partial_eval_jaxpr_custom(
            jaxpr,
            in_unknowns=[*is_undefined_primal, *[True] * num_outputs],
            in_inst=[*is_undefined_primal, *[True] * num_outputs],
            ensure_out_inst=[],
            ensure_out_unknowns=[],
            saveable=_save_everything)
    res = jax_core.eval_jaxpr(res_jaxpr, (), *defined_primals)

    # Now that we have residual values, we run the tangent jaxpr. It takes as
    # input the residuals and all of the refs (at least, the ones that are
    # used in the body). Luckily, `tangent_jaxpr_` has all known and unknown
    # inputs!
    ad.backward_pass(
        tangent_jaxpr_, (), False, (), (*res, *undefined_primals, *cts), ())
    return []
  jaxpr_avals = [v.aval for v in jaxpr.invars]
  jaxpr_in_avals, jaxpr_out_avals = split_list(jaxpr_avals, [num_inputs])
  jaxpr_defined_in_avals, jaxpr_undefined_in_avals = partition_list(
      is_undefined_primal, jaxpr_in_avals)
  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(trans), [*jaxpr_defined_in_avals, *jaxpr_out_avals,
                            *jaxpr_undefined_in_avals])
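  # The transposed call reuses the original block mappings, reordered to
  # match the new operand order: defined primals and output cotangents are
  # the inputs, and the undefined (linear) primals become the outputs whose
  # cotangents we produce.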
  grid_spec = GridSpec(
      grid_spec.grid, (*defined_in_block_mappings, *out_block_mappings,
                       *undefined_in_block_mappings),
      grid_spec.mapped_dims)
  cts_out = pallas_call_p.bind(
      *defined_primals, *cts_in,
      jaxpr=jaxpr_trans,
      grid_spec=grid_spec,
      in_shapes=(*defined_in_shapes, *out_shapes),
      out_shapes=tuple(undefined_in_shapes),
      name=f"{name}_transpose",
      debug=debug,
      interpret=interpret,
      which_linear=(*defined_which_linear, *[True] * num_outputs),
      input_output_aliases=(),
      **compiler_params)
  cts_out_iter = iter(cts_out)
  return [next(cts_out_iter) if ud else None
          for ud in is_undefined_primal]
ad.primitive_transposes[pallas_call_p] = _pallas_call_transpose_rule

def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
                         dim: Union[int, batching.NotMapped],
                         block_mapping: Optional[BlockMapping]) -> BlockMapping:
@@ -347,7 +661,6 @@ def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
  out_specs = tuple(out_specs)
  flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
                     for x in flat_out_shapes]
-  @jax.jit
  def wrapped(*args):
    flat_args, in_tree = tree_util.tree_flatten(args)
    if grid is None:
18 changes: 9 additions & 9 deletions tests/pallas_test.py
@@ -752,15 +752,15 @@ def pallas_impl(x_ref, o_ref):
                    rtol=1e-5)
    jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)

-  # TODO(sharadmv): enable this when we update Triton
-  # def test_jvp_matmul(self):
-  #   k1, k2 = random.split(random.PRNGKey(0))
-  #   x = random.normal(k1, (256, 128))
-  #   y = random.normal(k2, (128, 64))
-  #   bm, bn, bk, gm = 64, 128, 32, 8
-  #   mm = functools.partial(matmul, bm=bm, bn=bn, bk=bk, gm=gm,
-  #                          interpret=self.INTERPRET)
-  #   jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)
+  # TODO(sharadmv): enable this when we update Triton
+  def test_jvp_matmul(self):
+    k1, k2 = random.split(random.PRNGKey(0))
+    x = random.normal(k1, (256, 128))
+    y = random.normal(k2, (128, 64))
+    bm, bn, bk, gm = 64, 128, 32, 8
+    mm = functools.partial(matmul, bm=bm, bn=bn, bk=bk, gm=gm,
+                           interpret=self.INTERPRET)
+    jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)

  def test_slicing_block_spec(self):
    @functools.partial(
