Skip to content

Commit

Permalink
Merge pull request #281 from rahulbatra85:rocm_updates
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668931126
  • Loading branch information
The jax_triton Authors committed Aug 29, 2024
2 parents feb3fc3 + 6742187 commit 12b8c8e
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 34 deletions.
167 changes: 136 additions & 31 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import copy
import dataclasses
import importlib.util
import functools
import os
import pprint
Expand All @@ -41,7 +42,8 @@
import jax.numpy as jnp
import numpy as np

CAN_USE_TRITON = False
CAN_USE_TRITON = importlib.util.find_spec("triton") is not None

try:
import triton
from triton.compiler import code_generator as code_gen
Expand All @@ -51,8 +53,7 @@
import triton._C.libtriton as _triton
from triton._C.libtriton import ir as tl_ir
import triton.backends.nvidia.compiler as cb

CAN_USE_TRITON = True
import triton.backends.amd.compiler as hb
except ModuleNotFoundError:
pass
try:
Expand Down Expand Up @@ -90,6 +91,14 @@
}


def is_device_rocm():
  """Returns whether the current JAX backend reports a ROCm platform.

  The check is a substring match for ``"rocm"`` in the backend's
  ``platform_version`` string.

  Returns:
    True if ``"rocm"`` occurs in the backend's ``platform_version``.
  """
  try:
    # `jax.lib.xla_bridge.get_backend` is deprecated in newer JAX releases in
    # favor of `jax.extend.backend.get_backend`. Prefer the supported entry
    # point and fall back only for JAX versions that do not provide it.
    from jax.extend import backend as jex_backend

    get_backend = jex_backend.get_backend
  except (ImportError, AttributeError):
    get_backend = jax.lib.xla_bridge.get_backend
  return "rocm" in get_backend().platform_version


def is_device_cuda():
  """Returns whether the current JAX backend reports a CUDA platform.

  The check is a substring match for ``"cuda"`` in the backend's
  ``platform_version`` string.

  Returns:
    True if ``"cuda"`` occurs in the backend's ``platform_version``.
  """
  try:
    # `jax.lib.xla_bridge.get_backend` is deprecated in newer JAX releases in
    # favor of `jax.extend.backend.get_backend`. Prefer the supported entry
    # point and fall back only for JAX versions that do not provide it.
    from jax.extend import backend as jex_backend

    get_backend = jex_backend.get_backend
  except (ImportError, AttributeError):
    get_backend = jax.lib.xla_bridge.get_backend
  return "cuda" in get_backend().platform_version


# A grid is either a single int or a tuple of one, two, or three ints.
Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
# Either a concrete `Grid`, or a callable that takes a dict with string keys
# and returns a `Grid`.
GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]

Expand Down Expand Up @@ -157,21 +166,46 @@ def aval_size_bytes(aval):


@dataclasses.dataclass
class PtxCompilationResult:
ptx: str
class CompilationResult:
binary: str
name: str
shared_mem_bytes: int
cluster_dims: tuple
ttgir: Optional[str]
llir: Optional[str]


def compile_ttir_inplace(
    ttir,
    backend: cb.CUDABackend | hb.HIPBackend,
    options: cb.CUDAOptions | hb.HIPOptions,
    compute_capability,
):
  """Compiles TTIR with the compiler that matches the active device.

  Dispatches to `compile_ttir_to_ptx_inplace` when `is_device_cuda()` is true
  and to `compile_ttir_to_hsaco_inplace` when `is_device_rocm()` is true. All
  arguments are forwarded unchanged, in the same order.

  Args:
    ttir: The TTIR module to compile.
    backend: The Triton backend object handed to the selected compiler.
    options: The parsed backend options handed to the selected compiler.
    compute_capability: Forwarded as-is to the selected compiler.

  Returns:
    Whatever the selected compiler returns.

  Raises:
    RuntimeError: If the device is neither CUDA nor ROCm.
  """
  # Pick the compiler first, then invoke it once. CUDA is checked before ROCm,
  # and the ROCm check is only evaluated when the CUDA check fails.
  if is_device_cuda():
    compile_fn = compile_ttir_to_ptx_inplace
  elif is_device_rocm():
    compile_fn = compile_ttir_to_hsaco_inplace
  else:
    raise RuntimeError("Unsupported device")
  return compile_fn(ttir, backend, options, compute_capability)


def compile_ttir_to_ptx_inplace(
ttir,
cuda_backend: cb.CUDABackend,
cuda_options: cb.CUDAOptions,
compute_capability,
) -> PtxCompilationResult:
) -> CompilationResult:
if cuda_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
Expand All @@ -188,7 +222,7 @@ def compile_ttir_to_ptx_inplace(
ttir = tl_ir.parse_mlir_module(f.name, context)
ttir.context = context
try:
metadata = dict()
metadata = {}
opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
ttgir = cuda_backend.make_ttgir(
opt_ttir,
Expand Down Expand Up @@ -226,8 +260,73 @@ def compile_ttir_to_ptx_inplace(
cluster_dims = metadata["cluster_dims"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
return PtxCompilationResult(
ptx=ptx,
return CompilationResult(
binary=ptx,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
ttgir=ttgir,
llir=llir,
)


def compile_ttir_to_hsaco_inplace(
ttir,
hip_backend: hb.HIPBackend,
hip_options: hb.HIPOptions,
compute_capability,
) -> CompilationResult:
if hip_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
context = _triton.ir.context()
_triton.ir.load_dialects(context)
hip_backend.load_dialects(context)

# Triton compilation APIs only accept Triton-specific MLIR wrappers.
# So, here we serialize an ir.Module to a file and then deserialize
# it as a tl_ir.module.
with tempfile.NamedTemporaryFile(mode="wb") as f:
ttir.operation.write_bytecode(f)
f.flush()
ttir = tl_ir.parse_mlir_module(f.name, context)
ttir.context = context
try:
metadata = {}
opt_ttir = hip_backend.make_ttir(ttir, metadata, hip_options)
ttgir = hip_backend.make_ttgir(opt_ttir, metadata, hip_options)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if hip_options.debug:
print(ttgir)
try:
llir = hip_backend.make_llir(ttgir, metadata, hip_options)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = metadata["shared"]
if hip_options.debug:
print(llir)

amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options)
hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options)

name = metadata["name"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
# Cluster dims are NOT useful on the HIP backend.
# We just fill them with a placeholder value for API compatibility.
cluster_dims = (0, 0, 0)
# Instead of passing the hsaco directly (it is "bytes"), we first write it
# to a file and then pass the file path as a "string". This is needed because
# nanobind doesn't automatically convert between bytes and string.
# https://github.com/wjakob/nanobind/discussions/137
fd, hsaco_path = tempfile.mkstemp()
with os.fdopen(fd, "wb") as f:
f.write(hsaco)
return CompilationResult(
binary=hsaco_path,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
Expand Down Expand Up @@ -296,29 +395,38 @@ def get_or_create_triton_kernel(
kernel = _COMPILED_KERNEL_CACHE.get(cache_key)

if kernel is None:
target = cb.GPUTarget('cuda', compute_capability, 32)
cuda_backend = cb.CUDABackend(target)
cuda_options = cuda_backend.parse_options(
dict(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
optimize_epilogue=False,
debug=dump,
enable_fp_fusion=enable_fp_fusion,
)
)
opts = {
"num_warps": num_warps,
"num_stages": num_stages,
"num_ctas": num_ctas,
"optimize_epilogue": False,
"debug": dump,
"enable_fp_fusion": enable_fp_fusion,
}
if is_device_cuda():
target = cb.GPUTarget("cuda", compute_capability, 32)
backend = cb.CUDABackend(target)
options = backend.parse_options(opts)
elif is_device_rocm():
arch = triton_kernel_call_lib.get_arch_details(device)
arch = arch.split(":")[0]
target = hb.GPUTarget("hip", arch, 64)
backend = hb.HIPBackend(target)
options = backend.parse_options(opts)
else:
raise ValueError("Unsupported device.")

kernel_hash = abs(hash(cache_key))
if _JAX_TRITON_DUMP_DIR:
os.makedirs(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}")
with open(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/config", "w") as f:
pprint.pprint(cache_key, stream=f)
pprint.pprint(cuda_options, stream=f)
pprint.pprint(options, stream=f)

context = _triton.ir.context()
_triton.ir.load_dialects(context)
cuda_backend.load_dialects(context)
codegen_fns = cuda_backend.get_codegen_implementation()
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()

module = code_gen.ast_to_ttir(
fn,
Expand All @@ -328,17 +436,14 @@ def get_or_create_triton_kernel(
signature=signature,
attrs=specialization_attr,
),
options=cuda_options,
options=options,
codegen_fns=codegen_fns,
context=context,
)
ttir = str(module)

compilation_result = compile_ttir_to_ptx_inplace(
module,
cuda_backend,
cuda_options,
compute_capability,
compilation_result = compile_ttir_inplace(
module, backend, options, compute_capability
)
kernel_name = compilation_result.name
if _JAX_TRITON_DUMP_DIR:
Expand Down Expand Up @@ -372,7 +477,7 @@ def get_or_create_triton_kernel(
kernel_name,
num_warps,
compilation_result.shared_mem_bytes,
compilation_result.ptx,
compilation_result.binary,
ttir,
compute_capability,
*compilation_result.cluster_dims,
Expand Down
6 changes: 3 additions & 3 deletions tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,16 +328,16 @@ def test_kernel_cache_equivalent_kernels(self):
x1, y1 = create_random_inputs([42])
x2, y2 = create_random_inputs([43])

compile_ttir_to_ptx_inplace = jt.triton_lib.compile_ttir_to_ptx_inplace
compile_ttir_inplace = jt.triton_lib.compile_ttir_inplace

call_count = [0]

def my_compile(*args, **kwargs):
call_count[0] += 1
return compile_ttir_to_ptx_inplace(*args, **kwargs)
return compile_ttir_inplace(*args, **kwargs)

with mock.patch.object(
jt.triton_lib, "compile_ttir_to_ptx_inplace", new=my_compile
jt.triton_lib, "compile_ttir_inplace", new=my_compile
):
_ = fn1(x1, y1)
self.assertEqual(call_count[0], 1)
Expand Down

0 comments on commit 12b8c8e

Please sign in to comment.