Reduce runtime dependency on torch #5490
base: main
Conversation
As I have mentioned, what you submit now would be ad-hoc. I don't suggest creating PRs like this at this moment.
Assuming the amd backend works without torch, would this PR be acceptable if a test was added that removes torch for the amd backend and checks that execution / autotuning still works as expected?
I agree, but I think ad-hoc changes are only bad if they're frequently subject to change or impossible to test. I think this change can be tested (and I can write tests if necessary), and it is morally similar to removing the circular dependencies with torch by moving all global torch imports to local ones (which, as far as I know, is also not explicitly tested and is simply fixed whenever it occurs). The number of changes in this PR is small relative to the number of torch imports in the code overall, so I doubt this PR will change substantially with future changes.
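For reference, a test along these lines could hide torch from the import machinery before importing triton. The sketch below is a hypothetical illustration only (the fixture name and assertions are assumptions, not something this PR adds):

```python
# Hypothetical sketch of a "no torch" import test; not part of this PR.
import importlib
import sys

import pytest


@pytest.fixture
def no_torch(monkeypatch):
    # Setting a sys.modules entry to None makes `import torch` raise ImportError.
    monkeypatch.setitem(sys.modules, "torch", None)
    # Drop any cached triton modules so they are re-imported under the restriction.
    for name in [m for m in sys.modules if m == "triton" or m.startswith("triton.")]:
        monkeypatch.delitem(sys.modules, name)


def test_import_triton_without_torch(no_torch):
    # Should succeed only if importing triton no longer requires torch.
    triton = importlib.import_module("triton")
    assert hasattr(triton, "jit")
```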
```diff
@@ -110,7 +112,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
     :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
     """
     assert return_mode in ["min", "max", "mean", "median", "all"]
-    import torch
+    import numpy as np

     di = runtime.driver.active.get_device_interface()
```
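For context on why numpy can stand in here: the tail of `do_bench` is just host-side reductions over per-iteration timings. The snippet below is a simplified sketch of that post-processing step using only numpy; it is not the PR's exact code.

```python
import numpy as np


def summarize(times_ms, quantiles=None, return_mode="mean"):
    """Reduce a list of per-iteration timings (in milliseconds) without torch.

    Simplified sketch of the post-processing step, not the actual code in
    python/triton/testing.py. `quantiles`, if given, is a sequence in [0, 1].
    """
    times = np.asarray(times_ms, dtype=np.float64)
    if quantiles is not None:
        ret = np.quantile(times, quantiles).tolist()
        return ret[0] if len(ret) == 1 else ret
    if return_mode == "all":
        return times.tolist()
    # "min", "max", "mean", and "median" map directly onto numpy reductions.
    return float(getattr(np, return_mode)(times))
```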
It doesn't really remove the torch dependency. `get_device_interface` still imports torch.
Depends on the driver. For cpu, it doesn't need torch, since there's currently only one device (the host cpu).
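To make that concrete, the handful of device-interface operations `do_bench` relies on (event timing and synchronization) can be provided for a single host device without torch. The class below is a hypothetical sketch of such an interface, not the actual triton-cpu driver code:

```python
import time


class CPUEvent:
    """Hypothetical host-side stand-in for a GPU timing event."""

    def __init__(self, enable_timing=False):
        self._t = None

    def record(self):
        self._t = time.perf_counter()

    def elapsed_time(self, other):
        # Mirror torch.cuda.Event.elapsed_time, which returns milliseconds.
        return (other._t - self._t) * 1e3


class CPUDeviceInterface:
    """Hypothetical torch-free device interface for a single-device CPU driver."""

    Event = CPUEvent

    def synchronize(self):
        # Host execution is synchronous; nothing to wait for.
        pass
```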
Hmm, I thought we were still going to import torch for the CPU backend anyway.
@minjang can you confirm that you want the CPU runtime to be completely independent of torch?
@Jokeren No, triton-cpu doesn't have a plan or a need to reduce/remove the torch dependency.
@stephen-huan, okay, I understand your intention. It's fine to reduce the dependency only for `third_party/cpu`. But, as you already had to change `python/triton/testing.py`, you would need to change many parts of the code outside `third_party/cpu`. And, due to (painful) rebasing and resolving merge conflicts, triton-cpu strongly wants to avoid such code changes. For example, `test_core.py` has heavy mixed usage of torch and numpy. Even though this is only for testing, we still end up with both torch and numpy dependencies. So, right now, I agree with @Jokeren.
```diff
@@ -43,6 +44,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     :type return_mode: str
     """
     import torch
```
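If the goal is only that the module imports without torch, one option discussed above is the local-import pattern: defer the torch import so that only the CUDA-graph path requires it. The sketch below illustrates that pattern under that assumption; it is not necessarily what this PR does.

```python
def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
    # Deferred (local) import: the module itself can be imported without torch,
    # and torch is only required once someone actually calls the CUDA-graph path.
    try:
        import torch
    except ImportError as e:
        raise RuntimeError("do_bench_cudagraph requires torch") from e
    # The existing torch/CUDA-graph benchmarking body would follow here.
    ...
```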
torch is still here.
I would suggest you think more about what you actually want to achieve.
Is it really removing all torch dependencies, or just making triton-cpu work better? The latter is probably more controllable and much easier.
Originally this was a triton-cpu PR, but they told me to upstream everything that wasn't CPU-backend specific; see triton-lang#205. So the goal of this PR is just to (1) import triton, (2) execute kernels, and (3) autotune with the cpu backend, all without torch. Of course, removing torch entirely from nvidia/amd is currently out of scope, because torch is used as a convenient gpu library from python.
Sorry, there seems to be some confusion around #5493, which I submitted around the same time as this PR. #5493 is more of a feature request/tracking issue where I propose (1) executing kernels on jax/numpy arrays directly, without needing a `Pointer` shim, and (2) having the interpreter work on jax/numpy arrays without the `Data` + `Pointer` shims. This PR addresses neither of these concerns.
The goal of this PR is simply to be able to use triton on the cpu backend without importing torch.
In general, I'm very supportive of improving triton-cpu compatibility.
If this is the case, why is there a torch->numpy replacement in this PR? Are these files not working in the triton-cpu repo?
Well, they work if torch is imported. But without torch, of course they don't work. And this doesn't meaningfully change anything for the gpu backends, since the statistics computations are done on the cpu anyway and the numpy methods have (roughly) the same semantics as the torch methods.
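As a quick illustration of the "(roughly) the same semantics" point, the host-side reductions used for timing summaries agree between the two libraries (assuming both are installed):

```python
import numpy as np
import torch

times = [1.2, 0.9, 1.5, 1.1]
t = torch.tensor(times, dtype=torch.float64)

# Mean and quantiles (both default to linear interpolation) match.
assert np.isclose(np.mean(times), t.mean().item())
assert np.allclose(np.quantile(times, [0.2, 0.5, 0.8]),
                   torch.quantile(t, torch.tensor([0.2, 0.5, 0.8], dtype=torch.float64)).numpy())

# One small difference hidden behind "roughly": on even-length inputs,
# torch.median returns the lower middle element, while np.median averages the two.
```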
With the torch->numpy replacement, this package ends up with mixed use of torch and numpy. It seems to me that it's in a middle state that does not completely address the problem if triton-cpu wants to be independent of torch.
Thus, I'm not quite sure why the changes are necessary.
Let's wait for triton-cpu maintainers to get involved. We probably need more context for further discussion.
Currently, torch is required for importing triton and performing autotuning. This seems like a relatively heavy runtime dependency in the context of the cpu backend, as numpy can easily be used instead.
Opening here as suggested in triton-lang#205 to minimize future merge conflicts.
Ideally there would be a test for this, but with the cpu backend out-of-tree this seems hard to test.
See also triton-lang#204, triton-lang#205.
New contributor declaration
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - `/test` for `lit` tests
  - `/unittest` for C++ tests
  - `/python/test` for end-to-end tests
- Select one of the following.
  - I have not added any `lit` tests.
  - The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually, running Python code and using the instructions it generates is not minimal.)