diff --git a/jax_triton/__init__.py b/jax_triton/__init__.py index a923c88..fed5623 100644 --- a/jax_triton/__init__.py +++ b/jax_triton/__init__.py @@ -14,7 +14,6 @@ """Library for JAX-Triton integrations.""" import jaxlib -from jax._src.lib import gpu_triton from jax_triton import utils from jax_triton.triton_lib import triton_call from jax.experimental.pallas import cdiv @@ -23,8 +22,13 @@ from jax_triton.version import __version__ from jax_triton.version import __version_info__ -get_compute_capability = gpu_triton.get_compute_capability +if jaxlib.version.__version_info__ >= (0, 4, 25): + from jax._src.pallas import triton + get_compute_capability = triton.get_compute_capability + if jaxlib.version.__version_info__ >= (0, 4, 14): + from jax._src.lib import gpu_triton + get_compute_capability = gpu_triton.get_compute_capability try: get_serialized_metadata = gpu_triton.get_serialized_metadata except AttributeError: diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 78ffaa7..2c2234c 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -47,7 +47,6 @@ import triton.language as tl from triton.runtime import autotuner import triton._C.libtriton as _triton - from triton._C.libtriton import ir as tl_ir import triton.backends.nvidia.compiler as cb CAN_USE_TRITON = True @@ -172,7 +171,7 @@ def compile_ttir_to_ptx_inplace( with tempfile.NamedTemporaryFile(mode="wb") as f: ttir.operation.write_bytecode(f) f.flush() - ttir = tl_ir.parse_mlir_module(f.name, context) + ttir = _triton.ir.parse_mlir_module(f.name, context) ttir.context = context try: metadata = dict()