A lot of boilerplate for TRITON_INTERPRET=1 without torch #5493
Comments
Yes. It is the assumption.
Please do not submit a PR now. Your topic goes beyond simple changes in the interpreter, since it amounts to generalizing the runtime support beyond torch. I would suggest investigating more thoroughly and proposing some ideas first if you really want to dig into it.
Here are my two cents: a consensus has to be reached among Triton developers about whether this is something we want to pursue at this time. In the long term, I think it's a reasonable topic.
The interpreter is one thing, but would a PR supporting passing jax/numpy arrays to kernels, like is done with pytorch tensors, be acceptable? That is an inherently local change, since the kernel only needs to know the pointer/dtype at execution (that is, making the first two ...):

triton/third_party/nvidia/backend/driver.py, lines 311 to 342 in c1166e5
This concern is reasonable. I do think it is relatively easy to test: essentially, just run the current tests without torch. In my own testing I've uninstalled torch, but this could be done systematically in a test with some sort of runtime patching.
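For what it's worth, a minimal sketch of that runtime patching, assuming pytest's `monkeypatch` (the fixture name `no_torch` is illustrative, not an existing Triton test utility):

```python
import builtins
import sys

import pytest


@pytest.fixture
def no_torch(monkeypatch):
    # Drop any already-imported torch modules so cached imports can't leak in.
    for name in [m for m in sys.modules if m == "torch" or m.startswith("torch.")]:
        monkeypatch.delitem(sys.modules, name)

    # Make any future `import torch` fail, simulating an uninstalled torch.
    real_import = builtins.__import__

    def guarded_import(name, *args, **kwargs):
        if name == "torch" or name.startswith("torch."):
            raise ImportError("torch is disabled for this test")
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", guarded_import)
```

Tests that request the fixture would then exercise the torch-free code paths directly.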
It depends on to what extent you want to remove the torch dependency. Even though you can get around the issue to support JAX/numpy arrays, it doesn't mean that all utilities are compatible with JAX.
Describe the bug
(This is more of a feature request than a bug, and not a very pressing one, so feel free to ignore.)
On the thread of triton-lang#204, it is possible to use triton-cpu with numpy/jax via the following `Pointer` shims (note that in the case of jax on gpu, it's possible to use jax-triton; see e.g. jax-ml/jax-triton#322 for an extension to cpu).
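A hedged reconstruction of the kind of shim meant here (illustrative, not the exact snippet from triton-lang#204): a thin wrapper exposing the two torch-tensor attributes the kernel launcher actually needs.

```python
import numpy as np


class Pointer:
    """Wrap a numpy array so it looks torch-tensor-like to the kernel launcher."""

    def __init__(self, array: np.ndarray):
        self.array = array
        # Triton inspects .dtype when specializing the kernel signature.
        self.dtype = array.dtype

    def data_ptr(self) -> int:
        # Host address of the underlying buffer, mirroring torch's .data_ptr().
        return self.array.ctypes.data
```

A kernel can then be launched as `my_kernel[grid](Pointer(x), Pointer(y), ...)` on triton-cpu.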
However, when `TRITON_INTERPRET=1`, the amount of boilerplate required increases drastically (the round trip could probably be written more efficiently with `jax.device_put` and `jax.device_get`).
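A minimal sketch of the kind of round trip this refers to, assuming the interpreter only accepts torch-like tensors (`run_interpreted` is a hypothetical helper, not a Triton API):

```python
import numpy as np
import torch  # the interpreter currently expects torch-like tensors


def run_interpreted(kernel, grid, *arrays, **kwargs):
    """Round-trip jax/numpy inputs through torch under TRITON_INTERPRET=1."""
    # np.asarray copies jax arrays to host numpy; for plain numpy arrays it
    # is a no-op, and torch.from_numpy then shares their memory.
    host = [np.asarray(a) for a in arrays]
    kernel[grid](*[torch.from_numpy(h) for h in host], **kwargs)
    # In-place kernel writes land in `host`; jax callers still have to copy
    # results back to device themselves (e.g. with jax.device_put).
    return host
```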
This seems to be mostly a consequence of these lines in the interpreter:
triton/python/triton/runtime/interpreter.py, lines 1053 to 1078 in c1166e5
It would be nice if the interpreter could support jax/numpy without all the boilerplate, especially because the interpreter lowers to numpy on cpu anyway. It would be extra nice if passing jax/numpy arrays "just worked" like pytorch tensors.
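Concretely, "just working" would mean something like the following (aspirational, not current behavior):

```python
import numpy as np
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = np.arange(8, dtype=np.float32)
y = np.ones_like(x)
out = np.empty_like(x)
# The hope: numpy arrays accepted directly, exactly as torch tensors are today.
add_kernel[(1,)](x, y, out, 8, BLOCK_SIZE=8)
```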
As I primarily write jax, this is not so relevant for me, since jax has jax-triton and Pallas (which has its own interpret mode). But given that (roughly) numpy : cpu :: pytorch : gpu, it would be nice if numpy were "blessed" for the cpu backend.
I would submit a PR, but it seems triton assumes things are torch-tensor-like in all sorts of places, in a much more global manner than #5490. Naively, it might be possible to simply add additional checks when the kernel is being executed (`.data_ptr()`, `.unsafe_buffer_pointer()`, `.ctypes.data`), but there's too much I don't understand about triton's organization (for example, what is `TensorWrapper` doing in `jit.py`, and why does it have torch semantics?).

triton/third_party/nvidia/backend/driver.py, lines 311 to 342 in c1166e5
triton/python/triton/runtime/jit.py, lines 876 to 910 in c1166e5
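For concreteness, a hedged sketch of what those additional checks could look like (`_extract_ptr` is a hypothetical helper, not existing Triton code):

```python
def _extract_ptr(arg) -> int:
    """Return a raw buffer address for torch, jax, or numpy arguments."""
    if hasattr(arg, "data_ptr"):               # torch tensors
        return arg.data_ptr()
    if hasattr(arg, "unsafe_buffer_pointer"):  # jax arrays
        return arg.unsafe_buffer_pointer()
    if hasattr(arg, "ctypes"):                 # numpy arrays
        return arg.ctypes.data
    raise TypeError(f"cannot take the pointer of {type(arg)!r}")
```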
(originally filed as triton-lang#206.)
Environment details
triton-cpu: daa7eb0