Torch dependency for importing triton, kernel execution and autotuning #204

Closed
stephen-huan opened this issue Dec 24, 2024 · 3 comments · May be fixed by #205
Labels
bug Something isn't working

Comments


stephen-huan commented Dec 24, 2024

Describe the bug

Consider the following, adapted from 01-vector-add.py to use NumPy instead of Torch. On GPU, Triton depends on Torch for a number of reasons that would be hard to replace (e.g. interfacing with CUDA from Python), but on CPU, Torch is a relatively heavy dependency just for creating tensors, and NumPy is strictly smaller (Torch itself depends on NumPy).

import numpy as np

import triton
import triton.language as tl

rng = np.random.default_rng(0)


class Pointer:
    # Duck-typed stand-in for a torch tensor: the Triton launcher only
    # needs a .dtype attribute and a .data_ptr() method to hand the
    # underlying buffer to the kernel.

    def __init__(self, data: np.ndarray) -> None:
        self.data = data
        self.dtype = data.dtype

    def data_ptr(self):
        return self.data.ctypes.data


@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: np.ndarray, y: np.ndarray):
    output = np.empty_like(x)
    n_elements = output.size
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](Pointer(x), Pointer(y), Pointer(output), n_elements)
    return output


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],
        x_vals=[2**i for i in range(12, 28, 1)],
        x_log=True,
        line_arg="provider",
        line_vals=["triton", "numpy"],
        line_names=["Triton", "Numpy"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="GB/s",
        plot_name="vector-add-performance",
        args={},
    )
)
def benchmark(size, provider):
    x = rng.random(size, dtype=np.float32)
    y = rng.random(size, dtype=np.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == "numpy":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: add(x, y), quantiles=quantiles
        )
    gbps = lambda ms: 3 * x.size * x.itemsize * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    size = 98432
    x = rng.random(size)
    y = rng.random(size)
    output_numpy = x + y
    output_triton = add(x, y)
    print(output_numpy)
    print(output_triton)
    print(
        f"The maximum difference between numpy and triton is "
        f"{np.max(np.abs(output_numpy - output_triton))}"
    )
    benchmark.run(print_data=True, show_plots=False)

Currently, when Torch is not installed, this fails with:

Traceback (most recent call last):
  File "...", line 19, in <module>
    @triton.autotune(
     ^^^^^^^^^^^^^^^^
  File ".../triton-cpu/python/triton/runtime/autotuner.py", line 361, in decorator
    return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../triton-cpu/python/triton/runtime/autotuner.py", line 127, in __init__
    self.do_bench = driver.active.get_benchmarker()
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../triton-cpu/python/triton/runtime/driver.py", line 33, in __getattr__
    self._initialize_obj()
  File ".../triton-cpu/python/triton/runtime/driver.py", line 30, in _initialize_obj
    self._obj = self._init_fn()
                ^^^^^^^^^^^^^^^
  File ".../triton-cpu/python/triton/runtime/driver.py", line 13, in _create_driver
    actives = [x.driver for x in backends.values() if x.driver.is_active()]
                                                      ^^^^^^^^^^^^^^^^^^^^
  File ".../triton-cpu/python/triton/backends/amd/driver.py", line 495, in is_active
    import torch
ModuleNotFoundError: No module named 'torch'
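
As the traceback shows, no kernel launch is needed to reproduce this: anything that forces driver discovery hits the same import. A minimal reproduction that just re-triggers the call chain from the traceback:

from triton.runtime import driver

# Any attribute access on driver.active lazily initializes the active
# driver, which probes each backend's is_active() and therefore tries
# to import torch.
driver.active.get_benchmarker()  # ModuleNotFoundError: No module named 'torch'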

Environment details

triton-cpu: daa7eb0
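
For reference, one possible direction for a fix (a sketch only; #205 may take a different approach): guard the import in the backend's is_active() so that a missing Torch means "inactive" rather than an import error during discovery. This assumes the existing check simply tests torch.version.hip:

import importlib.util

def is_active():
    # Sketch of triton/backends/amd/driver.py::is_active with a guard:
    # if torch isn't installed, the backend cannot be active, so report
    # False instead of raising ModuleNotFoundError during discovery.
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    # Assumed existing check: a HIP build of torch marks the AMD
    # backend as active.
    return torch.version.hip is not None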

@ienkovich
Collaborator

All tutorials use Torch as a reference for both functionality and performance. We want to compare Triton's performance with native Torch performance, not NumPy. So it's not just there to make tensors; it provides reference performance numbers. Also, it's preferable to be able to run any tutorial on any device.

@stephen-huan
Author

Sorry about the confusion. This issue just uses the tutorial to illustrate the runtime dependency on Torch; the associated PR does not suggest changing the tutorials, only removing this dependency on Torch. I agree that this was not obvious from reading the issue alone.

@ienkovich
Collaborator

I see. In this case, it would be better to open an issue for each particular case where a dependency on Torch seems unreasonable. Please note that any related changes outside of the CPU backend (third_party/cpu) should go through the upstream repo.
