Skip to content

Commit

Permalink
Integrate Triton up to [bfb8e413](https://github.com/openai/triton/co…
Browse files Browse the repository at this point in the history
  • Loading branch information
gflegar authored and The jax_triton Authors committed Mar 11, 2024
1 parent 08f1063 commit 93ff85f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/triton_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice


@triton.jit
Expand Down Expand Up @@ -71,7 +72,7 @@ def tanh_kernel(
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
output = tl.math.tanh(x)
output = libdevice.tanh(x)
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)

Expand Down

0 comments on commit 93ff85f

Please sign in to comment.