Skip to content

Commit

Permalink
[pallas] Add support for custom call target name and serialized metad…
Browse files Browse the repository at this point in the history
…ata to Triton backend.

PiperOrigin-RevId: 549345856
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Jul 19, 2023
1 parent 695afcf commit 588ad6c
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions jax_triton/pallas/triton_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import dataclasses
import functools
import operator
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
import zlib

import jax
Expand Down Expand Up @@ -1636,7 +1636,8 @@ def pallas_call_lowering(
debug: bool,
input_output_aliases: Tuple[Tuple[int, int], ...],
grid_mapping: GridMapping,
**compiler_params: Any
triton_params: Optional[Dict[str, Any]] = None,
**compiler_params: Any,
):
if interpret:
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
Expand Down Expand Up @@ -1702,20 +1703,23 @@ def pallas_call_lowering(
for shape in out_shapes
]

xc.register_custom_call_target(
name, triton_kernel_call_lib.get_custom_call(), platform="CUDA"
)

if triton_params is None:
triton_params = {}
serialized_metadata = triton_params.get("serialized_metadata", b"")

return jaxlib.hlo_helpers.custom_call(
call_target_name="triton_kernel_call",
call_target_name=name,
out_types=out_types,
operands=in_nodes,
backend_config=zlib.compress(kernel_call.to_proto(b"")),
backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
operand_layouts=triton_utils.avals_to_layouts(ctx.avals_in),
result_layouts=triton_utils.avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
)


mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda")
xc.register_custom_call_target(
"triton_kernel_call",
triton_kernel_call_lib.get_custom_call(),
platform="CUDA",
)

0 comments on commit 588ad6c

Please sign in to comment.