WIP AD
sharadmv committed May 8, 2023
1 parent 02184d7 commit 1b22199
Showing 2 changed files with 156 additions and 48 deletions.
179 changes: 136 additions & 43 deletions jax_triton/pallas/pallas_call.py
@@ -156,7 +156,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts
jvp_which_linear = (*which_linear, *(True,) * len(tangents))
jvp_inshapes = (*in_shapes, *in_shapes)
_, nonzero_tangent_in_shapes = partition_list(nonzero_tangents, in_shapes)
jvp_inshapes = (*in_shapes, *nonzero_tangent_in_shapes)
jvp_outshapes = (*out_shapes, *out_shapes)
if input_output_aliases:
raise NotImplementedError("`input_output_aliases` jvp not supported.")
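Editor's sketch (not part of this commit's diff): the hunk above only materializes tangent inputs for primals whose tangents are nonzero, using `partition_list`, and the partial-eval rule further down re-interleaves results with `merge_lists`. Below both helpers are re-implemented locally for illustration; the real helpers and their exact return order live inside JAX and are assumed here.

```python
# Minimal sketch of the nonzero-tangent filtering used in the JVP rule above.
# partition_list/merge_lists are local stand-ins for JAX's utility helpers.
def partition_list(mask, xs):
  falses = [x for m, x in zip(mask, xs) if not m]
  trues = [x for m, x in zip(mask, xs) if m]
  return falses, trues

def merge_lists(mask, falses, trues):
  falses, trues = iter(falses), iter(trues)
  return [next(trues) if m else next(falses) for m in mask]

nonzero_tangents = [True, False, True]            # which primals have nonzero tangents
in_shapes = ["x_shape", "y_shape", "z_shape"]
_, nonzero_tangent_in_shapes = partition_list(nonzero_tangents, in_shapes)
assert nonzero_tangent_in_shapes == ["x_shape", "z_shape"]
# The JVP call then takes (*in_shapes, *nonzero_tangent_in_shapes) as inputs.

known, unknown = ["k0"], ["u0", "u1"]
assert merge_lists(nonzero_tangents, known, unknown) == ["u0", "k0", "u1"]
```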
@@ -172,7 +173,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)])
new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
nonzero_in_bms, _ = partition_list(nonzero_tangents, in_bms)
new_bms = tuple((*in_bms, *nonzero_in_bms, *out_bms, *out_bms))
new_grid_spec = grid_spec.replace(block_mappings=new_bms)
jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
*logical_tangent_inputs,
@@ -291,12 +293,13 @@ def _pallas_call_partial_eval(
jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr,
in_inst=all_unknowns,
in_inst=True,
in_unknowns=all_unknowns,
ensure_out_unknowns=[],
ensure_out_inst=[],
saveable=_save_everything)
# # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
breakpoint()
# `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
# regular valued input/outputs. However, we'd like to bind these jaxprs to a
# `for`, which expects only `Ref` inputs and no output. We need to convert
# both of these jaxprs into ones that are compatible with `for`.
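Editor's sketch (not part of this commit's diff): the comment above describes splitting a jaxpr into a "known" part that also emits residuals and an "unknown" part that consumes residuals plus the unknown inputs. Here is a hand-written illustration of that split using plain functions; `f_known` and `f_unknown` are hypothetical stand-ins, not the jaxprs produced by `pe.partial_eval_jaxpr_custom`.

```python
# Hand-written illustration of the known/unknown split performed by partial
# evaluation: the known part stashes residuals, the unknown part finishes the
# computation from residuals plus the unknown inputs.
import jax.numpy as jnp

def f(x, t):                    # x is known, t is unknown (e.g. a tangent)
  return jnp.sin(x) * t

def f_known(x):                 # known part: no primal outputs here, one residual
  residuals = (jnp.sin(x),)
  return (), residuals

def f_unknown(residuals, t):    # unknown part: residuals + unknown inputs
  (sin_x,) = residuals
  return sin_x * t

x, t = 1.0, 2.0
_, res = f_known(x)
assert jnp.allclose(f_unknown(res, t), f(x, t))
```

The surrounding code then rewrites both halves so they take only `Ref` inputs and produce no outputs, as the `for` primitive requires.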
@@ -339,13 +342,13 @@ def _pallas_call_partial_eval(
for a in res_avals
]
res_block_mappings = [
BlockMapping((*[None] * len(grid), *a.shape), index_map)
BlockMapping((*[pallas_core.mapped] * len(grid), *a.shape), index_map)
for a, index_map in zip(res_avals, res_index_mappings)
]
known_grid_spec = GridSpec(grid, (*known_in_block_mappings,
*known_out_block_mappings,
*res_block_mappings),
grid_spec.mapped_dims)
mapped_dims)
unknown_grid_spec = GridSpec(grid, (*res_block_mappings,
*unknown_in_block_mappings,
*unknown_out_block_mappings),
@@ -362,7 +365,7 @@
input_output_aliases=(),
which_linear=tuple(known_which_linear),
**compiler_params)
known_outputs, residuals = split_list(known_out_and_res, [len(known_tracers)])
known_outputs, residuals = split_list(known_out_and_res, [len(known_out_shapes)])
residuals = map(trace.new_instantiated_const, residuals)
unknown_inputs = [*residuals, *unknown_tracers]
unknown_outputs = [
@@ -373,8 +376,7 @@
source = source_info_util.current().replace(name_stack=name_stack)
unknown_params = dict(
jaxpr=jaxpr_unknown,
in_shapes=(*(jax.ShapeDtypeStruct(s.shape, s.dtype) for s in res_avals),
*unknown_in_shapes),
in_shapes=(*res_shapes, *unknown_in_shapes),
out_shapes=tuple(unknown_out_shapes),
grid_spec=unknown_grid_spec,
which_linear=(*res_which_linear, *unknown_which_linear),
@@ -390,40 +392,6 @@
return merge_lists(out_unknowns, known_outputs, unknown_outputs)
pe.custom_partial_eval_rules[pallas_call_p] = _pallas_call_partial_eval

def _transpose_jaxpr(jaxpr: jax_core.Jaxpr, which_linear: Sequence[bool]
) -> jax_core.Jaxpr:
num_inputs = len(which_linear)
num_outputs = len(jaxpr.invars) - num_inputs
def trans(*args):
# First we want to run the computation to read all the residual refs. We can
# do that by using partial evaluation with all linear inputs unknown.
res_jaxpr, tangent_jaxpr_, *_ = \
pe.partial_eval_jaxpr_custom(jaxpr,
in_unknowns=[*which_linear, *[True] *
num_outputs],
in_inst=[*which_linear, *[True] *
num_outputs],
ensure_out_inst=[],
ensure_out_unknowns=[],
saveable=_save_everything)
res_args = [x for x, lin in zip(args, which_linear) if not lin]
res = jax_core.eval_jaxpr(res_jaxpr, (), *res_args)

# Now that we have residual values, we run the tangent jaxpr. It takes as
# input the residuals, and all the refs (at least, the ones
# that are used in the body). Luckily, `tangent_jaxpr_` has all known and
# unknown inputs!
breakpoint()
primals_args = [*(r for u, r in zip(used_res, res) if u)]
ct_args = [x for x, u in zip(args, used_ct) if u]
ad.backward_pass(
tangent_jaxpr, (), False, (), (*res, *ct_args), ())
breakpoint()
return []
jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
return jaxpr_trans
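Editor's sketch (not part of this commit's diff): the deleted helper above tried to transpose the kernel jaxpr by first evaluating a residual jaxpr and then running `ad.backward_pass` on the tangent part. For intuition, this is what transposition means for a function that is linear in its tangent argument, shown with the public `jax.linear_transpose` API rather than the removed helper.

```python
# Intuition for the transpose rule: a function linear in t, t -> sin(x) * t,
# transposes to ct -> sin(x) * ct, with sin(x) playing the role of a residual.
# This illustrates the math only, not the removed _transpose_jaxpr helper.
import jax
import jax.numpy as jnp

x = jnp.float32(1.0)                      # residual value
f_lin = lambda t: jnp.sin(x) * t          # linear in t
f_transpose = jax.linear_transpose(f_lin, jnp.float32(0.0))
(ct_t,) = f_transpose(jnp.float32(3.0))   # cotangent w.r.t. t
assert jnp.allclose(ct_t, jnp.sin(x) * 3.0)
```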

def _pallas_call_transpose_rule(cts_in, *args,
jaxpr: jax_core.Jaxpr,
name: str,
@@ -592,6 +560,105 @@ def _pallas_call_batching_rule(args, dims, *,
return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule

class TritonCompilationResult(NamedTuple):
name: str
asm: Dict[str, str]
shared_mem: int
lowering_result: lowering.TritonLoweringResult

@weakref_lru_cache
def _compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec,
name: str, num_warps: int, num_stages: int
) -> TritonCompilationResult:
lowering_result = lowering.lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name)
backend = tc.runtime.backend.CUDA
device = 0
name, asm, shared_mem = tc.code_gen.compile_ttir(backend, lowering_result.module, device,
num_warps, num_stages, {}, 0)
return TritonCompilationResult(name, asm, shared_mem, lowering_result)


def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
which_linear: Tuple[bool, ...],
interpret: bool,
debug: bool,
input_output_aliases: Tuple[Tuple[int, int], ...],
grid_spec: GridSpec,
**compiler_params: Any):
if interpret:
return mlir.lower_fun(_pallas_call_impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
which_linear=which_linear,
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_spec=grid_spec, **compiler_params)
num_warps = compiler_params.get("num_warps", 4)
num_stages = compiler_params.get("num_stages", 3)
compilation_result = _compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)),
grid_spec, name, num_warps, num_stages)
name = compilation_result.name
asm = compilation_result.asm
shared_mem = compilation_result.shared_mem
ref_effects = state.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)
is_accum = [
all(isinstance(eff, state.AccumEffect) for eff in ref_effect)
for ref_effect in ref_effects
]
if debug:
print(jaxpr)
print(grid_spec)
lowering_result = compilation_result.lowering_result
if debug:
lowering_result.module.print()
out_type = ir.TupleType.get_tuple([
ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
for out_shape in ctx.avals_out])
i32_type = ir.IntegerType.get_signless(32)

kernel = triton_kernel_call_lib.TritonKernel(
asm["cubin"], name, num_warps, shared_mem
)

grid = normalize_grid(compilation_result.lowering_result.grid, metaparams={})
# All arguments are buffers.
all_args = [None] * (len(in_shapes) + len(out_shapes))
kernel_call = triton_kernel_call_lib.TritonKernelCall(
kernel, grid[0], grid[1], grid[2], all_args,
is_accum,
[s.size for s in [*in_shapes, *out_shapes]]
)

ctx.module_context.add_keepalive(kernel_call)
output_operand_aliases = ir.ArrayAttr.get([
mhlo.OutputOperandAlias.get(
output_tuple_indices=[output],
operand_index=input,
operand_tuple_indices=[])
for input, output in input_output_aliases
])
out = mhlo.CustomCallOp(
[out_type],
in_nodes,
call_target_name=ir.StringAttr.get("triton_kernel_call"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(kernel_call.descriptor),
api_version=ir.IntegerAttr.get(i32_type, 1),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
output_operand_aliases=output_operand_aliases,
)
results = [mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
for i in range(len(out_shapes))]
return results
mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda")
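Editor's sketch (not part of this commit's diff): the lowering above either falls back to the interpreter (`interpret=True`) or compiles the kernel with Triton, threading `num_warps`/`num_stages` through `**compiler_params`. A hedged usage sketch of how that branch is reached from user code; the import path and keyword names are assumptions based on this era of `jax_triton.pallas`.

```python
# Hedged usage sketch for the lowering above: interpret=True takes the
# mlir.lower_fun(_pallas_call_impl) path, while the default path compiles the
# kernel via _compile_jaxpr using num_warps/num_stages from **compiler_params.
import jax
import jax.numpy as jnp
import jax_triton.pallas as pl   # assumed import path

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)
add_one = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,   # exercise the interpreter fallback instead of Triton
)
print(add_one(x))     # [1. 2. ... 8.]
```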

@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: Optional[str] = None):
Expand Down Expand Up @@ -633,6 +700,32 @@ def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
return arg_shape
return tuple(s for s in block_spec.block_shape if s is not None)
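Editor's sketch (not part of this commit's diff): the hunk above squeezes `None` entries out of a `BlockSpec`'s `block_shape`, so dimensions that are mapped over by the grid do not appear in the kernel-visible block. A small illustration of that squeeze; the `BlockSpec` mentioned in the comment is an assumed example.

```python
# Illustration of the block-shape squeeze above: None entries mark dimensions
# mapped over by the grid, so the kernel sees a lower-rank block, e.g. for an
# assumed pl.BlockSpec(lambda i, j: (i, j), (None, 128)).
block_shape = (None, 128)
kernel_block_shape = tuple(s for s in block_shape if s is not None)
assert kernel_block_shape == (128,)
```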

def _pallas_call_bind(*args,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
which_linear: Tuple[bool, ...],
interpret: bool,
debug: bool,
input_output_aliases: Tuple[Tuple[int, int], ...],
grid_spec: GridSpec,
**compiler_params: Any):
num_inputs = len(in_shapes)
num_outputs = len(out_shapes)
assert len(jaxpr.invars) == num_inputs + num_outputs, (len(jaxpr.invars),
num_inputs,
num_outputs)
assert len(grid_spec.block_mappings) == len(jaxpr.invars)
return jax_core.Primitive.bind(
pallas_call_p, *args,
jaxpr=jaxpr, name=name, in_shapes=in_shapes,
out_shapes=out_shapes, which_linear=which_linear,
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_spec=grid_spec, **compiler_params)
pallas_call_p.def_custom_bind(_pallas_call_bind)
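Editor's sketch (not part of this commit's diff): `_pallas_call_bind` above wraps `jax_core.Primitive.bind` so that invariants (argument counts matching the jaxpr and grid spec) are checked on every bind. A minimal sketch of that `def_custom_bind` pattern on a toy primitive; the demo primitive and its rules are hypothetical, only the pattern mirrors the code above.

```python
# Minimal sketch of the def_custom_bind pattern: a wrapper that checks an
# invariant, then defers to jax.core.Primitive.bind on the same primitive.
from jax import core as jax_core

demo_p = jax_core.Primitive("demo")

def _demo_bind(x, *, n: int):
  assert n > 0, n                                # invariant checked at bind time
  return jax_core.Primitive.bind(demo_p, x, n=n)

demo_p.def_custom_bind(_demo_bind)
demo_p.def_impl(lambda x, *, n: x * n)
demo_p.def_abstract_eval(lambda x, *, n: x)

print(demo_p.bind(2.0, n=3))                     # -> 6.0
```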

def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
grid: Optional[Grid] = None,
in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
25 changes: 20 additions & 5 deletions tests/pallas_test.py
@@ -702,8 +702,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
("square", lambda x: x * x),
("add_one", lambda x: x + 1.),
("exp", jnp.exp),
# ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
# updated
("tanh", jnp.tanh),
])
def test_jvp(self, impl):
@functools.partial(
@@ -728,8 +727,7 @@ def pallas_impl(x_ref, o_ref):
("square", lambda x: x * x),
("add_one", lambda x: x + 1.),
("exp", jnp.exp),
# ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
# updated
("tanh", jnp.tanh),
])
def test_jvp_slice(self, impl):
@functools.partial(
@@ -752,7 +750,6 @@ def pallas_impl(x_ref, o_ref):
rtol=1e-5)
jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)

# TODO(sharadmv): enable this when we update Triton
def test_jvp_matmul(self):
k1, k2 = random.split(random.PRNGKey(0))
x = random.normal(k1, (256, 128))
@@ -778,6 +775,24 @@ def add_vectors(x_ref, y_ref, o_ref):
out_ref = xy[0] + xy[1]
np.testing.assert_allclose(out, out_ref)

@parameterized.named_parameters(*[
("square", lambda x: x * x),
("add_one", lambda x: x + 1.),
("exp", jnp.exp),
("tanh", jnp.tanh),
])
def test_grad(self, impl):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32))
def pallas_impl(x_ref, o_ref):
o_ref[...] = impl(x_ref[...])

x = random.normal(random.PRNGKey(0))
g = jax.grad(pallas_impl)(x)
g_ref = jax.grad(impl)(x)
np.testing.assert_allclose(g, g_ref, atol=1e-5, rtol=1e-5)
jtu.check_grads(pallas_impl, (x,), modes=["rev"], order=1)


class PallasCallVmapTest(PallasTest):

