A lot of boilerplate for TRITON_INTERPRET=1 without torch #5493

Open
stephen-huan opened this issue Dec 25, 2024 · 4 comments
Comments

@stephen-huan
Contributor

Describe the bug

(This is more of a feature request than a bug, and not a very pressing one, so feel free to ignore.)

On the thread of triton-lang#204, it is possible to use triton-cpu with numpy/jax with the following Pointer shims:

import jax.numpy as jnp
from jax import Array

import triton
import triton.language as tl


class Pointer:

    def __init__(self, data: Array) -> None:
        self.data = data
        self.dtype = data.dtype

    def data_ptr(self) -> int:
        return self.data.unsafe_buffer_pointer()


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


if __name__ == "__main__":
    x = jnp.ones(10)
    output = jnp.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)
and analogously with numpy:

import numpy as np

import triton
import triton.language as tl


class Pointer:

    def __init__(self, data: np.ndarray) -> None:
        self.data = data
        self.dtype = data.dtype

    def data_ptr(self) -> int:
        return self.data.ctypes.data


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


if __name__ == "__main__":
    x = np.ones(10)
    output = np.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)

(Note that in the case of jax on gpu, it's possible to use jax-triton; see e.g. jax-ml/jax-triton#322 for an extension to cpu.)
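For reference, a rough sketch of the jax-triton route on gpu for the same toy kernel (the wrapper function name is made up and this assumes jax_triton is installed; jax-triton allocates the output from out_shape and passes it as the final pointer argument of the kernel):

import jax
import jax.numpy as jnp
import jax_triton as jt

import triton
import triton.language as tl


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


def run_kernel(x: jax.Array) -> jax.Array:
    # returns an array of x's shape whose first element was written by the kernel
    return jt.triton_call(
        x,
        kernel=kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(1,),
    )


if __name__ == "__main__":
    print(run_kernel(jnp.ones(10)))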

However, when TRITON_INTERPRET=1, the amount of boilerplate required drastically increases.

import os

os.environ["TRITON_INTERPRET"] = "1"


import jax
import jax.numpy as jnp
from jax import Array

import triton
import triton.language as tl


class Data:

    def __init__(self, data: Array) -> None:
        self.data = data

    def copy_(self, other: Array) -> None:
        self.data = other


class Pointer:

    def __init__(self, data: Array) -> None:
        self.data = Data(data)
        self.dtype = data.dtype
        self.ptr = data.unsafe_buffer_pointer()
        self.device = data.devices().pop()

    def data_ptr(self) -> int:
        return self.ptr

    def cpu(self) -> "Pointer":
        return self.to(jax.devices(backend="cpu")[0])

    def to(self, device) -> "Pointer":
        return Pointer(self.data.data.to_device(device))


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


if __name__ == "__main__":
    x = jnp.ones(10)
    output = jnp.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)

(this could probably be written more efficiently with jax.device_put and jax.device_get.)
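A minimal sketch of that variant, assuming the same Data/Pointer interface as above (untested; jax.device_put performs the transfer to the requested device, and jax.device_get could similarly replace the host copy if the interpreter accepted plain numpy arrays):

import jax
from jax import Array


class Data:

    def __init__(self, data: Array) -> None:
        self.data = data

    def copy_(self, other: Array) -> None:
        self.data = other


class Pointer:

    def __init__(self, data: Array) -> None:
        self.data = Data(data)
        self.dtype = data.dtype
        self.ptr = data.unsafe_buffer_pointer()
        self.device = data.devices().pop()

    def data_ptr(self) -> int:
        return self.ptr

    def cpu(self) -> "Pointer":
        return self.to(jax.devices(backend="cpu")[0])

    def to(self, device) -> "Pointer":
        # jax.device_put copies the underlying buffer to `device`
        return Pointer(jax.device_put(self.data.data, device))

The corresponding numpy version of the interpreter boilerplate: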

import os

os.environ["TRITON_INTERPRET"] = "1"


import numpy as np

import triton
import triton.language as tl


class Data:

    def __init__(self, data: np.ndarray) -> None:
        self.data = data

    def copy_(self, other: np.ndarray) -> None:
        self.data = other


class Pointer:

    def __init__(self, data: np.ndarray) -> None:
        self.data = Data(data)
        self.dtype = data.dtype
        self.ptr = data.ctypes.data
        self.device = 0

    def data_ptr(self) -> int:
        return self.ptr

    def cpu(self) -> "Pointer":
        return self

    def to(self, device) -> "Pointer":
        return self


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


if __name__ == "__main__":
    x = np.ones(10)
    output = np.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)

This seems to be mostly a consequence of these lines in the interpreter.

def _init_args_hst(self, args_dev, kwargs):
    args_hst = []
    for arg in args_dev:
        if hasattr(arg, "data_ptr"):
            args_hst.append(arg.cpu())
        else:
            args_hst.append(arg)
    # Process keyword arguments
    kwargs_hst = {}
    for key, value in kwargs.items():
        if hasattr(value, "data_ptr"):
            kwargs_hst[key] = value.cpu()
        else:
            kwargs_hst[key] = value
    return args_hst, kwargs_hst

def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
    for arg_dev, arg_hst in zip(args_dev, args_hst):
        if hasattr(arg_dev, "data_ptr"):
            arg_dev.data.copy_(arg_hst.to(arg_dev.device).data)
    # Restore keyword arguments
    for key, kwarg_dev in kwargs.items():
        kwarg_hst = kwargs_hst[key]
        if hasattr(kwarg_dev, "data_ptr"):
            kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)

It would be nice if the interpreter could support jax/numpy without all the boilerplate, especially because the interpreter lowers to numpy on cpu anyway. It would be extra nice if passing jax/numpy arrays "just worked" like pytorch tensors.
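For illustration only (the helper below is hypothetical, not existing interpreter code): one direction would be for the host-copy step to recognize arguments that already live in host memory, such as numpy arrays, and pass them through unchanged, so that no cpu()/to()/copy_() shims are needed:

import numpy as np


def _to_host(arg):
    # hypothetical helper: a numpy array already lives in host memory, so it
    # can serve as its own "host copy" in the interpreter
    if isinstance(arg, np.ndarray):
        return arg
    # torch-like objects: fall back to the existing data_ptr()/cpu() protocol
    if hasattr(arg, "data_ptr"):
        return arg.cpu()
    return arg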

As I primarily write jax, this is not so relevant for me, since jax has jax-triton and pallas (which has its own interpret mode). But given that (roughly) numpy : cpu :: pytorch : gpus, it would be nice if numpy were "blessed" for the cpu backend.

I would submit a PR, but it seems triton assumes things are torch tensor-like in all sorts of places in a much more global manner than #5490. Naively, it might be possible to simply add additional checks when the kernel is being executed (.data_ptr(), .unsafe_buffer_pointer(), .ctypes.data), but there's too much I don't understand about triton's organization (for example, what is TensorWrapper doing in jit.py, and why does it have torch semantics?).
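To make the idea concrete, the check I have in mind would look roughly like this (a hypothetical helper, not anything that exists in triton today):

import numpy as np


def _data_ptr(arg) -> int:
    # hypothetical duck-typed pointer extraction for torch / jax / numpy
    if hasattr(arg, "data_ptr"):  # torch tensors and torch-like wrappers
        return arg.data_ptr()
    if hasattr(arg, "unsafe_buffer_pointer"):  # jax arrays
        return arg.unsafe_buffer_pointer()
    if isinstance(arg, np.ndarray):  # numpy arrays
        return arg.ctypes.data
    raise TypeError(f"cannot take the pointer of a {type(arg)}")

For reference, the pointer handling in the cuda driver and the TensorWrapper from jit.py: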

PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
  PyObject *empty_tuple = PyTuple_New(0);
  PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
  Py_DECREF(empty_tuple);
  Py_DECREF(ptr);
  if (!PyLong_Check(ret)) {{
    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    ptr_info.valid = false;
    return ptr_info;
  }}
  ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
  if(!ptr_info.dev_ptr)
    return ptr_info;
  uint64_t dev_ptr;
  int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  if (status == CUDA_ERROR_INVALID_VALUE) {{
    PyErr_Format(PyExc_ValueError,
                 "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
    ptr_info.valid = false;
  }} else if (status != CUDA_SUCCESS) {{
    CUDA_CHECK(status);  // Catch any other cuda API errors
    ptr_info.valid = false;
  }}
  ptr_info.dev_ptr = dev_ptr;
  Py_DECREF(ret);  // Thanks ChatGPT!
  return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
ptr_info.valid = false;
return ptr_info;
}}

class TensorWrapper:

    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.data = base.data
        self.device = base.device
        self.shape = self.base.shape

    def data_ptr(self):
        return self.base.data_ptr()

    def stride(self, i):
        return self.base.stride(i)

    def __str__(self) -> str:
        return f"TensorWrapper[{self.dtype}]({self.base})"

    def element_size(self):
        return self.base.element_size()

    def cpu(self):
        return TensorWrapper(self.base.cpu(), self.dtype)

    def copy_(self, other):
        self.base.copy_(other.base)

    def clone(self):
        return TensorWrapper(self.base.clone(), self.dtype)

    def to(self, device):
        return TensorWrapper(self.base.to(device), self.dtype)

    def new_empty(self, sizes):
        return TensorWrapper(self.base.new_empty(sizes), self.dtype)

(originally filed as triton-lang#206.)

Environment details

triton-cpu: daa7eb0

@Jokeren
Contributor

Jokeren commented Dec 25, 2024

triton assumes things are torch tensor-like in all sorts of places in a much more global manner

Yes. It is the assumption.

I would submit a PR,

Please do not submit a PR now.

Your topic is beyond simple changes in the interpreter as it considers generalizing the runtime support beyond torch. If you really want to dig into it, I would suggest investigating more thoroughly and proposing some ideas first.

@Jokeren Jokeren added enhancement and removed bug labels Dec 25, 2024
@Jokeren
Contributor

Jokeren commented Dec 25, 2024

Here are my two cents: a consensus has to be reached among Triton developers about whether this is something we want to pursue at this time. In the long term, I think it's a reasonable topic.
However, a roadmap needs to be planned, as this shouldn't be an ad-hoc fix. Instead, I envision a series of PRs being merged. Otherwise, there’s no guarantee that random changes won’t break some features, leading to reversion.

@stephen-huan
Contributor Author

stephen-huan commented Dec 25, 2024

Your topic is beyond simple changes in the interpreter as it considers generalizing the runtime support beyond torch.

The interpreter is one thing, but would a PR that supports passing jax/numpy arrays to kernels, the way pytorch tensors are passed, be acceptable? That is an inherently local change made by adding extra logic in the drivers, since the kernel only needs to know the pointer/dtype at execution time (that is, it would make the first two Pointer shims in this issue unnecessary).

PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
  PyObject *empty_tuple = PyTuple_New(0);
  PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
  Py_DECREF(empty_tuple);
  Py_DECREF(ptr);
  if (!PyLong_Check(ret)) {{
    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    ptr_info.valid = false;
    return ptr_info;
  }}
  ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
  if(!ptr_info.dev_ptr)
    return ptr_info;
  uint64_t dev_ptr;
  int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  if (status == CUDA_ERROR_INVALID_VALUE) {{
    PyErr_Format(PyExc_ValueError,
                 "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
    ptr_info.valid = false;
  }} else if (status != CUDA_SUCCESS) {{
    CUDA_CHECK(status);  // Catch any other cuda API errors
    ptr_info.valid = false;
  }}
  ptr_info.dev_ptr = dev_ptr;
  Py_DECREF(ret);  // Thanks ChatGPT!
  return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
ptr_info.valid = false;
return ptr_info;
}}

Otherwise, there’s no guarantee that random changes won’t break some features, leading to reversion.

This concern is reasonable. I do think it is relatively easy to test: essentially, just run the current tests without torch. In my own testing I've uninstalled torch, but this could be done systematically in a test with some sort of runtime patching, as sketched below.
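A sketch of what such runtime patching could look like with pytest (the test name is hypothetical and only the mechanism is shown; a real test would go on to import triton and run kernels after hiding torch):

import sys

import pytest


def test_runs_without_torch(monkeypatch):
    # Setting sys.modules["torch"] to None makes any subsequent `import torch`
    # raise ImportError, approximating an environment without torch installed.
    # Modules that already imported torch keep their reference, so the code
    # under test must be (re)imported after this point.
    monkeypatch.setitem(sys.modules, "torch", None)
    with pytest.raises(ImportError):
        import torch  # noqa: F401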

@Jokeren
Contributor

Jokeren commented Dec 26, 2024

The interpreter is one thing, but would a PR that supports passing jax/numpy arrays to kernels, the way pytorch tensors are passed, be acceptable?
This concern is reasonable. I do think it is relatively easy to test

It depends on to what extent you want to remove the torch dependency. Even if you can work around the issue to support JAX/numpy arrays, it doesn't mean that all utilities are compatible with JAX. The driver and the testing package still have a lot of dependencies on torch.
