
[fp8] Support fp8e4m3 in torch_xla #8005

Open
miladm opened this issue Sep 13, 2024 · 5 comments

miladm commented Sep 13, 2024

🚀 Feature

Please enable fp8e4m3 in torch_xla. This feature is in flight in openxla: https://github.com/openxla/xla/pull/16585/files

Today, PyTorch doesn't yet support fp8e4m3; only the fnuz variants are supported. @amithrm wants to see this feature as an alternative to fp8e4m3fn.
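
For context, a minimal probe (not from the original report; the dtype names are the ones listed later in this thread) showing which fp8 dtypes a given torch build actually exposes:

import torch

# Quick check: which float8 dtypes does this torch build expose?
# torch.float8_e4m3 and torch.float8_e3m4 are expected to be missing,
# while the fn/fnuz variants and float8_e5m2 should be present.
for name in ("float8_e4m3", "float8_e3m4", "float8_e4m3fn",
             "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz"):
    print(name, hasattr(torch, name))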

cc @amithrm @JackCaoG

miladm commented Sep 26, 2024

@apivovarov @amithrm from the compiler side, which missing components do you need to see materialize for this feature?

apivovarov commented Oct 24, 2024

Related PRs:

  • LLVM PR-97179 [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-97118 [MLIR] Add f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
  • LLVM PR-101230 [MLIR] Add f8E3M4 IEEE 754 type (Merged)
  • StableHLO PR-2486 [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
  • StableHLO PR-2482 Add f8E4M3 and f8E3M4 types support (Merged)
  • ml_dtypes PR-161 Add float8_e4m3 (Merged)
  • ml_dtypes PR-171 Add float8_e3m4 (Merged)
  • XLA PR-17075 [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Merged)
  • XLA PR-16585 Add support for float8_e4m3 and float8_e3m4 types (Merged)
  • JAX PR-23585 Add float8_e4m3 and float8_e3m4 types support (Approved, Pending XLA CompatibilityRequirement::WEEK_12 to switch to StableHLO 1.7.0 - scheduled on or after Nov 28, 2024)

lsy323 commented Oct 24, 2024

Hi @apivovarov @amithrm, can you share your use case for the fp8 variants that are not available in PyTorch core?

apivovarov commented Oct 24, 2024

Hi Siyuan,

Here’s an example of a PyTorch/XLA code snippet that works for the float8_e5m2 type but fails for float8_e4m3 because the latter type does not (yet) exist in the PyTorch module.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

dtype = torch.float8_e5m2
# dtype = torch.float8_e4m3   # does not exist as of now
# dtype = torch.float8_e3m4   # does not exist as of now


def foo(a, b):
  return a @ b


a = torch.ones(4, 4, dtype=dtype).to(device)
b = torch.ones(4, 4, dtype=dtype).to(device)

y = foo(a, b)
xm.mark_step()  # materialize and execute the lazily traced graph on the XLA device

print(y)

Currently, PyTorch supports the following f8 types:

  • torch.float8_e4m3fn
  • torch.float8_e4m3fnuz
  • torch.float8_e5m2
  • torch.float8_e5m2fnuz

XLA supports the following f8 types:

  • float8_e3m4 (supported by AWS Trn1 instances, missing in PyTorch)
  • float8_e4m3 (supported by AWS Trn1 instances, missing in PyTorch)
  • float8_e4m3fn
  • float8_e4m3fnuz
  • float8_e4m3b11fnuz
  • float8_e5m2 (supported by AWS Trn1 instances, exists in PyTorch)
  • float8_e5m2fnuz

A description of the float8 types can be found in llvm/ADT/APFloat.h.
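
As an aside (not part of the original comment), a minimal sketch of exercising the missing IEEE f8 types today via ml_dtypes, assuming an ml_dtypes build that already includes the float8_e4m3 / float8_e3m4 PRs listed above; numpy has no native f8 matmul, so the arithmetic is done in float32:

import numpy as np
import ml_dtypes  # registers float8_e4m3 / float8_e3m4 as numpy dtypes

# Build the same 4x4 operands as the torch_xla snippet, but with the
# IEEE float8_e4m3 type that PyTorch does not expose yet.
a = np.ones((4, 4), dtype=ml_dtypes.float8_e4m3)
b = np.ones((4, 4), dtype=ml_dtypes.float8_e4m3)

# numpy cannot matmul f8 operands directly, so upcast, multiply, downcast.
y = (a.astype(np.float32) @ b.astype(np.float32)).astype(ml_dtypes.float8_e4m3)
print(y)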

miladm commented Oct 25, 2024

@lsy323

  • Do we have a PyTorch issue requesting this dtype support that we can tie to this one? Let's clarify whether PyTorch plans to offer the support in the 2.6 release.
  • As discussed offline, this use case may differ somewhat from the INT4 one, making this dtype enablement a different kind of challenge. Question: if we support this new dtype on a select set of ops that @apivovarov would approve, can we make progress without the PyTorch support?
