From df95f362ad7139c3a6e9e06112459c20d46e24ae Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:38:24 +0800 Subject: [PATCH 1/2] upgrade triton dep --- jax_triton/triton_lib.py | 21 ++++++++++++++++----- pyproject.toml | 2 +- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 8e678d8..55ce478 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -156,12 +156,23 @@ def compile_ttir_to_ptx_inplace( compute_capability = triton_kernel_call_lib.get_compute_capability(device) if num_stages is None: num_stages = 3 if compute_capability >= 75 else 2 + # TODO (jon-chuang): handle the Hopper case of num_ctas > 1 + # (CTAs are Thread Block Clusters in NVIDIA speak) + num_ctas = 1 + + extra = { + 'cluster_info': _triton.ClusterInfo(), + 'enable_warp_specialization': False, + 'enable_persistent': False, + 'optimize_epilogue': False, + } if dump: print(ttir) try: - ttir = tc.optimize_ttir(ttir, compute_capability) - ttgir = tc.ttir_to_ttgir(ttir, num_warps) - ttgir = tc.optimize_ttgir(ttgir, num_stages, compute_capability) + ttir = tc.optimize_ttir(ttir, arch=compute_capability) + ttgir = tc.ttir_to_ttgir(ttir, num_warps=num_warps, num_ctas=num_ctas, arch=compute_capability,) + ttgir = tc.optimize_ttgir(ttgir, + num_stages=num_stages, num_warps=num_warps, num_ctas=num_ctas, arch=compute_capability, **extra) except RuntimeError as e: ttir.dump() raise ValueError("TTIR->TTGIR pass failed!") from e @@ -169,14 +180,14 @@ def compile_ttir_to_ptx_inplace( print(ttgir) extern_libs = {} try: - llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability) + llir = tc.ttgir_to_llir(ttgir, extern_libs, arch=compute_capability, tma_infos=_triton.TMAInfos()) except RuntimeError as e: ttgir.dump() raise ValueError("TTGIR->LLIR pass failed!") from e shared_mem_bytes = _triton.get_shared_memory_size(ttgir) if dump: print(llir) - ptx = tc.llir_to_ptx(llir, compute_capability) + ptx = tc.llir_to_ptx(llir, arch=compute_capability) if dump: print(ptx) name = ptx_get_kernel_name(ptx) diff --git a/pyproject.toml b/pyproject.toml index bb7fc13..b3a7d23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9,<3.11" dependencies = [ "absl-py>=1.4.0", "jax @ git+https://github.com/google/jax@a0c1265bbae2c3ec644d6181f23264b4794e9eac", - "triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230714011643/triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + "triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230822000928/triton_nightly-2.1.0.dev20230822000928-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=39f23718220984746c7fc5831d65c2c990eaf2d755a1c16bcba24946515ef0f6" ] [project.optional-dependencies] From 04f5ae00a13053e755af70d5c446dabfee4795a6 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Thu, 21 Sep 2023 22:00:44 -0400 Subject: [PATCH 2/2] minor --- jax_triton/triton_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 55ce478..846bc7a 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -157,7 +157,7 @@ def compile_ttir_to_ptx_inplace( if num_stages is None: num_stages = 3 if compute_capability >= 75 else 2 # TODO (jon-chuang): handle the Hopper case of num_ctas > 1 - # (CTAs are Thread Block Clusters in NVIDIA speak) + # (CTAs > 1 in Triton involve Thread Block Clusters only available on Hopper) num_ctas = 1 extra = {