[fp8] Support fp8e4m3 in torch_xla #8005
Comments
@apivovarov @amithrm from the compiler side, what missing components are you looking to see materialize re: this feature?
Related PRs:
Hi @apivovarov @amithrm, can you share your use case around the fp8 variants that are not available in pytorch core?
Hi Siyuan, here's an example of a PyTorch/XLA code snippet that works for the currently supported fp8 dtype:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

dtype = torch.float8_e5m2
# dtype = torch.float8_e4m3  # does not exist as of now
# dtype = torch.float8_e3m4  # does not exist as of now

def foo(a, b):
    return a @ b

a = torch.ones(4, 4, dtype=dtype).to(device)
b = torch.ones(4, 4, dtype=dtype).to(device)

y = foo(a, b)
xm.mark_step()
print(y)

Currently PyTorch supports the following f8 types:
XLA supports the following f8 types:
A description of the float8 types can be found in llvm/ADT/APFloat.h.
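A quick way to check which of these fp8 dtypes a given PyTorch build actually exposes is to probe the torch module at runtime; a minimal sketch (the list of candidate names below is illustrative, not taken from this thread):

```python
import torch

# Candidate float8 dtype names; entries not present in this build
# (e.g. float8_e4m3 or float8_e3m4) are reported as missing.
fp8_names = [
    "float8_e5m2",
    "float8_e5m2fnuz",
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e4m3",  # requested in this issue
    "float8_e3m4",  # requested in this issue
]

for name in fp8_names:
    status = "available" if hasattr(torch, name) else "missing"
    print(f"torch.{name}: {status}")
```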
🚀 Feature
Please enable fp8e4m3 in torch_xla. This feature is in flight in openxla: https://github.com/openxla/xla/pull/16585/files

Today, PyTorch doesn't support fp8e4m3 yet; only the funz variants are supported. @amithrm wants to see this feature as an alternative to fp8e4m3fn.

cc @amithrm @JackCaoG
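As a sketch of the alternative mentioned above, the matmul example from earlier in the thread can be run today with torch.float8_e4m3fn, which PyTorch core already ships; whether this particular matmul lowers on a given XLA backend is an assumption here, not something this issue confirms:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# float8_e4m3fn exists in PyTorch core; using it on an XLA device as a
# stand-in for fp8e4m3 is an assumption, not confirmed by this issue.
dtype = torch.float8_e4m3fn

a = torch.ones(4, 4, dtype=dtype).to(device)
b = torch.ones(4, 4, dtype=dtype).to(device)

y = a @ b
xm.mark_step()  # materialize the lazy computation
print(y)
```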