
Reduce runtime dependency on torch #5490

Open

stephen-huan wants to merge 1 commit into main

Conversation

stephen-huan
Contributor

Currently, torch is required for importing triton and performing autotuning. This seems like a relatively heavy runtime dependency in the context of the cpu backend, as numpy can easily be used instead.

Opening here as suggested in triton-lang#205 to minimize future merge conflicts.

Ideally there would be a test for this, but with the cpu backend out-of-tree this seems hard to test.

See also triton-lang#204, triton-lang#205.
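
For illustration, this is the kind of replacement the PR makes (a minimal sketch, not the actual diff; the helper name `_summarize` is hypothetical, while the real change touches `do_bench` in python/triton/testing.py):

```python
import numpy as np

# Hypothetical helper sketching the torch -> numpy swap: do_bench's summary
# statistics operate on host-side timing data (a list of milliseconds), so
# numpy can stand in for torch here.
def _summarize(times_ms, quantiles=None, return_mode="mean"):
    times = np.asarray(times_ms, dtype=np.float64)
    if quantiles is not None:
        # quantiles is assumed to be a list of floats; np.quantile matches
        # torch.quantile's default linear interpolation.
        ret = np.quantile(times, quantiles).tolist()
        return ret[0] if len(ret) == 1 else ret
    if return_mode == "all":
        return times.tolist()
    # Dispatch to np.min / np.max / np.mean / np.median.
    return float(getattr(np, return_mode)(times))
```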

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because it is not (currently) easy to test, and basic functionality should be covered by existing tests.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices, including the
      "tests should be minimal" section. (Usually running Python code and
      using the instructions it generates is not minimal.)

Jokeren (Contributor) commented Dec 25, 2024

As I have mentioned, what you submit now would be ad hoc. I don't suggest creating PRs like this at the moment.

stephen-huan (Contributor, Author)

Assuming the amd backend works without torch, would this PR be acceptable if a test were added that removes torch for the amd backend and checks that execution / autotuning still works as expected?

> As I have mentioned, what you submit now would be ad hoc.

I agree, but I think ad-hoc changes are only bad if they're frequently subject to change or impossible to test. I think this change can be tested (and I can write tests if necessary, as in the sketch below), and it is morally similar to removing the circular dependencies with torch by moving all global torch imports to local ones (which, as far as I know, is also not explicitly tested and simply fixed whenever it occurs). The number of changes in this PR is small relative to the number of torch imports in the code overall, so I doubt that this PR will change substantially with future changes.
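
For instance, a test along these lines could cover the import path (a hypothetical sketch, not part of this PR; it masks torch in a subprocess and verifies that importing triton still succeeds):

```python
import subprocess
import sys

# Run in a subprocess so masking torch cannot leak into the test session.
# Setting sys.modules["torch"] to None makes any later 'import torch'
# raise ImportError, per the documented import-system behavior.
SCRIPT = """
import sys
sys.modules["torch"] = None
import triton
import triton.language as tl
"""

def test_import_triton_without_torch():
    subprocess.run([sys.executable, "-c", SCRIPT], check=True)
```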

@@ -110,7 +112,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
     :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
     """
     assert return_mode in ["min", "max", "mean", "median", "all"]
-    import torch
+    import numpy as np

     di = runtime.driver.active.get_device_interface()
Contributor

It doesn't really remove the torch dependency.

get_device_interface still imports torch.

stephen-huan (Contributor, Author) commented Dec 26, 2024

Depends on the driver. For cpu, it doesn't need torch, since there's currently only one device (the host cpu).
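
As a sketch of what that can look like (hypothetical code, not the actual triton-cpu driver; it assumes the benchmarking path only needs an Event type and synchronize()):

```python
import time

# Hypothetical torch-free device interface for a single-device CPU backend.
class CPUDeviceInterface:

    class Event:
        def __init__(self, enable_timing=True):
            self.time = None

        def record(self):
            self.time = time.perf_counter()

        def elapsed_time(self, end):
            # Milliseconds, mirroring torch.cuda.Event.elapsed_time.
            return (end.time - self.time) * 1000.0

    @staticmethod
    def synchronize():
        # Host execution is synchronous; there is nothing to wait on.
        pass
```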

Contributor

Hmm, I thought we were still going to import torch for the CPU backend anyway.

@minjang can you confirm that you want the CPU runtime to be completely independent of torch?

minjang (Contributor)

@Jokeren No, triton-cpu doesn't have a plan or need to reduce/remove the torch dependency.

@stephen-huan, okay, I understand your intention. It's okay to reduce the dependency only for third_party/cpu. But, as you already had to change python/triton/testing.py, you will need to change many parts of the code outside third_party/cpu. And, due to (painful) rebasing and resolving merge conflicts, triton-cpu strongly wants to avoid such code changes. For example, test_core.py heavily mixes torch and numpy usage. Even if that is only for testing, we still have both torch and numpy as dependencies either way. So, right now, I agree with @Jokeren.

@@ -43,6 +44,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     :type return_mode: str
     """
     import torch
Jokeren (Contributor) commented Dec 26, 2024

torch is still here.

I would suggest you think more about what you actually want to achieve.

Is it really removing all torch dependencies, or just making triton-cpu work better? The latter is probably more controllable and much easier.

stephen-huan (Contributor, Author) commented Dec 26, 2024

Originally this was a triton-cpu PR, but they told me to upstream everything that wasn't cpu-backend specific; see triton-lang#205. So the goal of this PR is just to (1) import triton, (2) execute kernels, and (3) autotune with the cpu backend, all without torch. Of course, removing torch entirely from nvidia/amd is currently out of scope, because torch is used as a convenient gpu library from python.

stephen-huan (Contributor, Author) commented Dec 26, 2024

Sorry, there seems to be some confusion around #5493, which I submitted around the same time as this PR. #5493 is more of a feature request/tracking issue where I propose (1) executing kernels on jax/numpy arrays directly, without needing a Pointer shim and (2) having the interpreter work on jax/numpy arrays without the Data + Pointer shims. This PR addresses neither of these concerns.

The goal of this PR is simply to be able to use triton on the cpu backend without importing torch.

Contributor

In general, I'm very supportive of improving triton-cpu compatibility.

If that is the case, why is there a torch->numpy replacement in this PR? Do these files not work in the triton-cpu repo?

stephen-huan (Contributor, Author)

Well, they work if torch is imported; without torch, of course, they don't. And this doesn't meaningfully change anything for the gpu backends, since the statistics computations are done on cpu anyway and the numpy methods have (roughly) the same semantics as the torch methods.
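
A quick check of that rough equivalence for the measures do_bench computes (an illustrative snippet, assuming both libraries are installed):

```python
import numpy as np
import torch

times = [1.2, 0.9, 1.1, 1.4, 1.0]
t = torch.tensor(times, dtype=torch.float64)
a = np.asarray(times)

# Both libraries default to linear interpolation for quantiles.
assert np.isclose(torch.quantile(t, 0.5).item(), np.quantile(a, 0.5))
for mode in ("min", "max", "mean", "median"):
    assert np.isclose(getattr(torch, mode)(t).item(), getattr(np, mode)(a))
# "Roughly": e.g. for even-length inputs torch.median returns the lower of
# the two middle elements, while np.median averages them.
```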

Contributor

With the torch->numpy replacement, this package ends up with mixed use of torch and numpy. It seems to me to be in a middle state that does not completely address the problem, if triton-cpu wants to be independent of torch.

Thus, I'm not quite sure why the changes are necessary.

Let's wait for the triton-cpu maintainers to get involved. We probably need more context for further discussion.
