Round-to-zero float32 conversion not supported #58

Open

int3 opened this issue Jul 19, 2024 · 1 comment

int3 (Collaborator) commented Jul 19, 2024

Small repro case:

import triton
import triton.language as tl
import torch


@triton.jit
def type_convert_triton(src, dst, rounding: tl.constexpr):
    offsets = tl.arange(0, 128)
    x = tl.load(src + offsets)
    # Downcast float32 -> float16 with an explicit rounding mode.
    y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding)
    tl.store(dst + offsets, y)


src = torch.zeros([128], dtype=torch.float32)
dst = torch.empty(src.shape, dtype=torch.float16)
type_convert_triton[(1, )](src, dst, rounding='rtz')
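
For reference, round-toward-zero for an fp32 -> fp16 downcast can be emulated on the host by clearing the 13 low mantissa bits of each float32 value before the cast. A minimal numpy sketch (a hypothetical helper, not part of the repro; it assumes results land in float16's normal range and ignores subnormals and overflow):

import numpy as np

def fp32_to_fp16_rtz(x):
    # float32 has 23 mantissa bits, float16 has 10. Clearing the 13 low bits
    # truncates the magnitude toward zero, and the masked value is exactly
    # representable in float16 for normal-range inputs, so the final cast is
    # exact regardless of its rounding mode.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32).astype(np.float16)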

ienkovich (Collaborator) commented

I tried to enable this some time ago. At first I thought I had hit an LLVM bug, but it turned out that the arith::TruncFOp lowering doesn't guarantee a particular rounding mode: it generates llvm.experimental.constrained.fptrunc, and that intrinsic doesn't control rounding, it only tells the compiler which runtime rounding settings to assume: llvm/llvm-project#96815

I think we should lower it directly to vcvtps2ph intrinsic calls with an explicit rounding mode to make this work.
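
To make the difference concrete, here is a host-side numpy illustration (an assumption for illustration, not the kernel's actual codegen): the default cast rounds to nearest-even, which is what the unconstrained lowering effectively produces under default runtime rounding settings, while 'rtz' should truncate toward zero.

import numpy as np

# 1.0 + 2**-11 + 2**-13 lies between the adjacent float16 values 1.0 and
# 1.0009765625, above their midpoint, so the two rounding modes disagree.
x = np.array([1.0 + 2**-11 + 2**-13], dtype=np.float32)

print(x.astype(np.float16))       # [1.001] round-to-nearest-even rounds up
masked = (x.view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)
print(masked.astype(np.float16))  # [1.]    round-toward-zero truncates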
