WIP AD
sharadmv committed Feb 13, 2023
1 parent 6e6ec22 commit e05bd7c
Showing 4 changed files with 101 additions and 56 deletions.
jax_triton/pallas/lowering.py (22 changes: 18 additions & 4 deletions)
@@ -28,6 +28,7 @@
from jax._src.lax.control_flow import for_loop
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import core as jax_core
from jax._src import state
from jax._src.state import primitives as sp
@@ -572,9 +573,10 @@ def _addupdate_lowering_rule(ctx: TritonLoweringRuleContext, ptr, value,
else slc for s, slc in zip(avals_in[0].shape, idx))
idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
ptr = _offset_ptr(ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder, is_scalar)
old_value = tl.load(ptr, mask=mask, _builder=ctx.builder)
tl.store(ptr, old_value.__add__(value, _builder=ctx.builder),
mask=mask, _builder=ctx.builder)
tl.atomic_add(ptr, value, _builder=ctx.builder)
# old_value = tl.load(ptr, mask=mask, _builder=ctx.builder)
# tl.store(ptr, old_value.__add__(value, _builder=ctx.builder),
# mask=mask, _builder=ctx.builder)
return []
triton_lowering_rules[sp.addupdate_p] = _addupdate_lowering_rule
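A note on the `addupdate` change just above: the old load/add/store sequence is a non-atomic read-modify-write, so if several program instances accumulate into the same output block (which is what gradient accumulation does), updates can be lost; `tl.atomic_add` makes the whole update indivisible. The following is a plain-Python sketch of that failure mode, not Triton code; `racy_accumulate` and `atomic_accumulate` are hypothetical stand-ins used only to illustrate the difference:

```python
import threading

def racy_accumulate(buf, n_threads=8, n_iters=200_000):
    # Mirrors the old lowering: load, then add and store, with a window in
    # between where another thread's update can be overwritten.
    def worker():
        for _ in range(n_iters):
            tmp = buf[0]        # load
            buf[0] = tmp + 1    # add + store
    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads: t.start()
    for t in threads: t.join()

def atomic_accumulate(buf, n_threads=8, n_iters=200_000):
    # A lock makes the read-modify-write indivisible, which is the guarantee
    # an atomic add gives on the GPU.
    lock = threading.Lock()
    def worker():
        for _ in range(n_iters):
            with lock:
                buf[0] += 1
    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads: t.start()
    for t in threads: t.join()

racy, atomic = [0], [0]
racy_accumulate(racy)
atomic_accumulate(atomic)
print(racy[0])    # may fall short of 1_600_000 due to lost updates
print(atomic[0])  # always exactly 1_600_000
```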

@@ -621,9 +623,21 @@ def _reduce_argmin_lowering(ctx: TritonLoweringRuleContext, a, *, axes,
triton_lowering_rules[lax.argmin_p] = _reduce_argmin_lowering

def _xla_call_lowering_rule(ctx: TritonLoweringRuleContext, *args, call_jaxpr, **_):
return lower_jaxpr_to_triton_ir(ctx.context, call_jaxpr, *args)
return lower_jaxpr_to_triton_ir(ctx.context, call_jaxpr, ctx.block_infos, *args)
triton_lowering_rules[xla.xla_call_p] = _xla_call_lowering_rule

def _closed_call_lowering_rule(ctx: TritonLoweringRuleContext, *args, call_jaxpr, **_):
jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
if consts:
raise NotImplementedError
return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)
triton_lowering_rules[jax_core.closed_call_p] = _closed_call_lowering_rule

def _remat_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr, **_):
return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)
triton_lowering_rules[ad_checkpoint.remat_p] = _remat_lowering_rule


def _for_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr,
which_linear, nsteps, reverse, unroll):
current_bb = ctx.builder.get_insert_block()
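The three call-lowering rules this file now registers (`xla_call_p`, the new `closed_call_p`, and the new `remat_p`) all reduce to the same recursion: lower the inner jaxpr with the caller's `ctx.block_infos`. The practical effect is that call-like primitives appearing inside a kernel body, for example from `jax.checkpoint`, now have a lowering path. A hedged usage sketch, assuming jax-triton with Pallas and a CUDA GPU are available (kernel and shapes invented for illustration):

```python
import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def kernel(x_ref, o_ref):
    # jax.checkpoint introduces remat_p into the kernel jaxpr; the new
    # _remat_lowering_rule simply recurses into the wrapped jaxpr.
    f = jax.checkpoint(lambda x: jnp.tanh(x) * x)
    o_ref[...] = f(x_ref[...])

def run(x):
    return pl.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

x = jnp.arange(8.0)
# print(run(x))  # uncomment on a machine with Triton and a CUDA GPU
```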
jax_triton/pallas/pallas_call.py (90 changes: 46 additions & 44 deletions)
@@ -165,7 +165,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts
jvp_which_linear = (*which_linear, *(True,) * len(tangents))
jvp_inshapes = (*in_shapes, *in_shapes)
_, nonzero_tangent_in_shapes = partition_list(nonzero_tangents, in_shapes)
jvp_inshapes = (*in_shapes, *nonzero_tangent_in_shapes)
jvp_outshapes = (*out_shapes, *out_shapes)
if input_output_aliases:
raise NotImplementedError("`input_output_aliases` jvp not supported.")
@@ -181,7 +182,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)])
new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
nonzero_in_bms, _ = partition_list(nonzero_tangents, in_bms)
new_bms = tuple((*in_bms, *nonzero_in_bms, *out_bms, *out_bms))
new_grid_spec = grid_spec.replace(block_mappings=new_bms)
jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
*logical_tangent_inputs,
@@ -300,12 +302,13 @@ def _pallas_call_partial_eval(
jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr,
in_inst=all_unknowns,
in_inst=True,
in_unknowns=all_unknowns,
ensure_out_unknowns=[],
ensure_out_inst=[],
saveable=_save_everything)
# # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
breakpoint()
# `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
# regular valued input/outputs. However, we'd like to bind these jaxprs to a
# `for`, which expects only `Ref` inputs and no output. We need to convert
# both of these jaxprs into ones that are compatible with `for`.
@@ -348,13 +351,13 @@ def _pallas_call_partial_eval(
for a in res_avals
]
res_block_mappings = [
BlockMapping((*[None] * len(grid), *a.shape), index_map)
BlockMapping((*[pallas_core.mapped] * len(grid), *a.shape), index_map)
for a, index_map in zip(res_avals, res_index_mappings)
]
known_grid_spec = GridSpec(grid, (*known_in_block_mappings,
*known_out_block_mappings,
*res_block_mappings),
grid_spec.mapped_dims)
mapped_dims)
unknown_grid_spec = GridSpec(grid, (*res_block_mappings,
*unknown_in_block_mappings,
*unknown_out_block_mappings),
@@ -371,7 +374,7 @@
input_output_aliases=(),
which_linear=tuple(known_which_linear),
**compiler_params)
known_outputs, residuals = split_list(known_out_and_res, [len(known_tracers)])
known_outputs, residuals = split_list(known_out_and_res, [len(known_out_shapes)])
residuals = map(trace.new_instantiated_const, residuals)
unknown_inputs = [*residuals, *unknown_tracers]
unknown_outputs = [
@@ -382,8 +385,7 @@
source = source_info_util.current().replace(name_stack=name_stack)
unknown_params = dict(
jaxpr=jaxpr_unknown,
in_shapes=(*(jax.ShapeDtypeStruct(s.shape, s.dtype) for s in res_avals),
*unknown_in_shapes),
in_shapes=(*res_shapes, *unknown_in_shapes),
out_shapes=tuple(unknown_out_shapes),
grid_spec=unknown_grid_spec,
which_linear=(*res_which_linear, *unknown_which_linear),
@@ -399,40 +401,6 @@
return merge_lists(out_unknowns, known_outputs, unknown_outputs)
pe.custom_partial_eval_rules[pallas_call_p] = _pallas_call_partial_eval

def _transpose_jaxpr(jaxpr: jax_core.Jaxpr, which_linear: Sequence[bool]
) -> jax_core.Jaxpr:
num_inputs = len(which_linear)
num_outputs = len(jaxpr.invars) - num_inputs
def trans(*args):
# First we want to run the computation to read all the residual refs. We can
# do that by using partial evaluation with all linear inputs unknown.
res_jaxpr, tangent_jaxpr_, *_ = \
pe.partial_eval_jaxpr_custom(jaxpr,
in_unknowns=[*which_linear, *[True] *
num_outputs],
in_inst=[*which_linear, *[True] *
num_outputs],
ensure_out_inst=[],
ensure_out_unknowns=[],
saveable=_save_everything)
res_args = [x for x, lin in zip(args, which_linear) if not lin]
res = jax_core.eval_jaxpr(res_jaxpr, (), *res_args)

# Now that we have residual values, we run the tangent jaxpr. It takes as
# input the residuals, and all the refs (at least, the ones
# that are used in the body). Luckily, `tangent_jaxpr_` has all known and
# unknown inputs!
breakpoint()
primals_args = [*(r for u, r in zip(used_res, res) if u)]
ct_args = [x for x, u in zip(args, used_ct) if u]
ad.backward_pass(
tangent_jaxpr, (), False, (), (*res, *ct_args), ())
breakpoint()
return []
jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
return jaxpr_trans

def _pallas_call_transpose_rule(cts_in, *args,
jaxpr: jax_core.Jaxpr,
name: str,
@@ -645,6 +613,12 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
name = compilation_result.name
asm = compilation_result.asm
shared_mem = compilation_result.shared_mem
ref_effects = state.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)
is_accum = [
all(isinstance(eff, state.AccumEffect) for eff in ref_effect)
for ref_effect in ref_effects
]
if debug:
print(jaxpr)
print(grid_spec)
@@ -664,7 +638,9 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
# All arguments are buffers.
all_args = [None] * (len(in_shapes) + len(out_shapes))
kernel_call = triton_kernel_call_lib.TritonKernelCall(
kernel, grid[0], grid[1], grid[2], all_args
kernel, grid[0], grid[1], grid[2], all_args,
is_accum,
[s.size for s in [*in_shapes, *out_shapes]]
)

ctx.module_context.add_keepalive(kernel_call)
@@ -735,6 +711,32 @@ def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
return arg_shape
return tuple(s for s in block_spec.block_shape if s is not None)

def _pallas_call_bind(*args,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
which_linear: Tuple[bool, ...],
interpret: bool,
debug: bool,
input_output_aliases: Tuple[Tuple[int, int], ...],
grid_spec: GridSpec,
**compiler_params: Any):
num_inputs = len(in_shapes)
num_outputs = len(out_shapes)
assert len(jaxpr.invars) == num_inputs + num_outputs, (len(jaxpr.invars),
num_inputs,
num_outputs)
assert len(grid_spec.block_mappings) == len(jaxpr.invars)
return jax_core.Primitive.bind(
pallas_call_p, *args,
jaxpr=jaxpr, name=name, in_shapes=in_shapes,
out_shapes=out_shapes, which_linear=which_linear,
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_spec=grid_spec, **compiler_params)
pallas_call_p.def_custom_bind(_pallas_call_bind)

def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
grid: Optional[Grid] = None,
in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
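On the JVP fix earlier in this file: the old code duplicated every input shape and block mapping for the tangent side, while the new code filters them with `partition_list` over `nonzero_tangents`, so the tangent-side metadata lines up with the tangent arguments that are actually threaded through. A minimal stand-in for the splitting behaviour being relied on (an illustration, not the `jax._src.util` implementation):

```python
def partition_list(flags, xs):
    # Split xs by flag, preserving order: (items with flag False, items with flag True).
    falses = [x for flag, x in zip(flags, xs) if not flag]
    trues = [x for flag, x in zip(flags, xs) if flag]
    return falses, trues

nonzero_tangents = [True, False, True]
in_shapes = ("x_shape", "y_shape", "z_shape")

# Old: jvp_inshapes = (*in_shapes, *in_shapes) duplicated all three shapes.
# New: only inputs with a nonzero tangent contribute a tangent-side shape.
_, nonzero_tangent_in_shapes = partition_list(nonzero_tangents, in_shapes)
print(nonzero_tangent_in_shapes)  # ['x_shape', 'z_shape']
```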
lib/triton_kernel_call.cc (20 changes: 17 additions & 3 deletions)
@@ -145,12 +145,23 @@ class TritonKernelCall : public TritonKernelCallBase {
public:
TritonKernelCall(TritonKernel& kernel, uint32_t grid_0, uint32_t grid_1,
uint32_t grid_2,
std::vector<std::optional<uint64_t>> parameters)
std::vector<std::optional<uint64_t>> parameters,
std::vector<bool> zero_out,
std::vector<uint64_t> sizes)
: kernel_(kernel),
grid_{grid_0, grid_1, grid_2},
parameters_(std::move(parameters)) {}
parameters_(std::move(parameters)),
zero_out_(std::move(zero_out)),
sizes_(sizes) {}

void Launch(CUstream stream, void** buffers) override final {
for (int i = 0; i < sizes_.size(); ++i) {
bool should_zero = zero_out_[i];
if (should_zero) {
uint64_t size = sizes_[i];
CHECK_CUDA(cuMemsetD8Async((CUdeviceptr) (buffers[i]), 0, size * 4, stream));
}
}
std::vector<void*> params;
params.reserve(parameters_.size());
for (std::optional<uint64_t>& param : parameters_) {
@@ -169,6 +180,8 @@ class TritonKernelCall : public TritonKernelCallBase {
uint32_t grid_[3];
// Parameter values. `nullopt` values represent buffer arguments.
std::vector<std::optional<uint64_t>> parameters_;
std::vector<bool> zero_out_;
std::vector<uint64_t> sizes_;
};

class TritonAutotunedKernelCall : public TritonKernelCallBase {
@@ -320,7 +333,8 @@ PYBIND11_MODULE(triton_kernel_call_lib, m) {

py::class_<TritonKernelCall>(m, "TritonKernelCall")
.def(py::init<TritonKernel&, uint32_t, uint32_t, uint32_t,
std::vector<std::optional<uint64_t>>>(),
std::vector<std::optional<uint64_t>>, std::vector<bool>,
std::vector<uint64_t>>(),
py::keep_alive<1, 2>()) // Ensure that the kernel lives long enough.
.def_property_readonly("descriptor", [](TritonKernelCall& kernel_call) {
union {
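Connecting the Python and C++ halves of the accumulation change: `pallas_call_lowering` now derives `is_accum` from each ref's state effects and passes it, along with per-buffer element counts, into `TritonKernelCall`; `Launch` then zeroes those buffers before the kernel runs, so in-kernel atomic adds accumulate on top of zeros rather than whatever was left in memory. A NumPy sketch of that contract with invented names (the real zeroing happens on device via `cuMemsetD8Async`, with a byte count of `size * 4`, i.e. 4-byte elements):

```python
import numpy as np

def launch(buffers, is_accum, sizes, kernel):
    # Mirror of TritonKernelCall::Launch: zero accumulator buffers up front...
    for buf, accum, size in zip(buffers, is_accum, sizes):
        if accum:
            buf[:size] = 0
    # ...then run the kernel, which only ever adds into those buffers.
    kernel(*buffers)

def kernel(x, o):
    o += 2.0 * x  # accumulation-only output: correct only if o starts at zero

x = np.arange(4, dtype=np.float32)
o = np.full(4, 123.0, dtype=np.float32)  # stale values that must not leak through
launch([x, o], is_accum=[False, True], sizes=[x.size, o.size], kernel=kernel)
print(o)  # [0. 2. 4. 6.]
```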
tests/pallas_test.py (25 changes: 20 additions & 5 deletions)
@@ -534,8 +534,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
("square", lambda x: x * x),
("add_one", lambda x: x + 1.),
("exp", jnp.exp),
# ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
# updated
("tanh", jnp.tanh),
])
def test_jvp(self, impl):
@functools.partial(
@@ -560,8 +559,7 @@ def pallas_impl(x_ref, o_ref):
("square", lambda x: x * x),
("add_one", lambda x: x + 1.),
("exp", jnp.exp),
# ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
# updated
("tanh", jnp.tanh),
])
def test_jvp_slice(self, impl):
@functools.partial(
@@ -584,7 +582,6 @@ def pallas_impl(x_ref, o_ref):
rtol=1e-5)
jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)

TODO(sharadmv): enable this when we update Triton
def test_jvp_matmul(self):
k1, k2 = random.split(random.PRNGKey(0))
x = random.normal(k1, (256, 128))
@@ -610,6 +607,24 @@ def add_vectors(x_ref, y_ref, o_ref):
out_ref = xy[0] + xy[1]
np.testing.assert_allclose(out, out_ref)

@parameterized.named_parameters(*[
("square", lambda x: x * x),
("add_one", lambda x: x + 1.),
("exp", jnp.exp),
("tanh", jnp.tanh),
])
def test_grad(self, impl):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32))
def pallas_impl(x_ref, o_ref):
o_ref[...] = impl(x_ref[...])

x = random.normal(random.PRNGKey(0))
g = jax.grad(pallas_impl)(x)
g_ref = jax.grad(impl)(x)
np.testing.assert_allclose(g, g_ref, atol=1e-5, rtol=1e-5)
jtu.check_grads(pallas_impl, (x,), modes=["rev"], order=1)


class PallasCallVmapTest(PallasTest):

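The new `test_grad` compares `jax.grad` of the Pallas kernel against `jax.grad` of the pure function and then runs `jtu.check_grads` in reverse mode, which roughly amounts to checking the reverse-mode derivative against a finite-difference estimate. A standalone sketch of that numerical check, independent of Pallas and jtu (not the jtu implementation):

```python
import jax
import jax.numpy as jnp

def check_rev_grad(f, x, eps=1e-3, atol=1e-3):
    g = jax.grad(f)(x)                           # reverse-mode derivative at x
    fd = (f(x + eps) - f(x - eps)) / (2 * eps)   # central finite difference
    assert jnp.allclose(g, fd, atol=atol), (g, fd)
    return g, fd

print(check_rev_grad(jnp.tanh, 0.7))
print(check_rev_grad(lambda x: x * x, 0.7))
```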
