diff --git a/third_party/py/ml_dtypes/workspace.bzl b/third_party/py/ml_dtypes/workspace.bzl index 51505bf3a1460d..fea50144ee1738 100644 --- a/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/py/ml_dtypes/workspace.bzl @@ -7,8 +7,8 @@ float8 varieties, and int4. load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - ML_DTYPES_COMMIT = "24084d9ed2c3d45bf83b7a9bff833aa185bf9172" - ML_DTYPES_SHA256 = "c916a3e6b3d9bdcb476f506fdbbecb6d5e9f21f82f221dfcb42b320b4e85e55a" + ML_DTYPES_COMMIT = "82f3a61d7cd80d607e08c7a3ecbc0dbb8fcfd0c7" + ML_DTYPES_SHA256 = "d06443d5423cf2a85ee49b5fc8fd0a2261a5fdbb267e8bd7f86535880144a658" tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", diff --git a/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl b/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl index 51505bf3a1460d..fea50144ee1738 100644 --- a/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl @@ -7,8 +7,8 @@ float8 varieties, and int4. load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - ML_DTYPES_COMMIT = "24084d9ed2c3d45bf83b7a9bff833aa185bf9172" - ML_DTYPES_SHA256 = "c916a3e6b3d9bdcb476f506fdbbecb6d5e9f21f82f221dfcb42b320b4e85e55a" + ML_DTYPES_COMMIT = "82f3a61d7cd80d607e08c7a3ecbc0dbb8fcfd0c7" + ML_DTYPES_SHA256 = "d06443d5423cf2a85ee49b5fc8fd0a2261a5fdbb267e8bd7f86535880144a658" tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", diff --git a/third_party/tsl/tools/def_file_filter/symbols_pybind.txt b/third_party/tsl/tools/def_file_filter/symbols_pybind.txt index 0bdf3111c133b9..1765ae5f1557b5 100644 --- a/third_party/tsl/tools/def_file_filter/symbols_pybind.txt +++ b/third_party/tsl/tools/def_file_filter/symbols_pybind.txt @@ -64,10 +64,12 @@ tsl::ml_dtypes::RegisterTypes tsl::ml_dtypes::GetBfloat16Dtype tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype tsl::ml_dtypes::GetFloat8E4m3fnDtype +tsl::ml_dtypes::GetFloat8E4m3Dtype tsl::ml_dtypes::GetFloat8E5m2Dtype tsl::ml_dtypes::GetBfloat16TypeNum tsl::ml_dtypes::GetFloat8E4m3b11fnuzTypeNum tsl::ml_dtypes::GetFloat8E4m3fnTypeNum +tsl::ml_dtypes::GetFloat8E4m3TypeNum tsl::ml_dtypes::GetFloat8E5m2TypeNum [//tensorflow/python:py_func_lib] # py_func diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index 916be8db4f6998..f237eb2af9a838 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_ #include "ml_dtypes/include/float8.h" // from @ml_dtypes -#include "ml_dtypes/include/intn.h" // from @ml_dtypes +#include "ml_dtypes/include/intn.h" // from @ml_dtypes namespace tsl { +using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; diff --git a/third_party/tsl/tsl/protobuf/dnn.proto b/third_party/tsl/tsl/protobuf/dnn.proto index 695db935f6a0b4..8dcf74f34624ca 100644 --- a/third_party/tsl/tsl/protobuf/dnn.proto +++ b/third_party/tsl/tsl/protobuf/dnn.proto @@ -22,6 +22,7 @@ enum DataType { kF8E5M2FNUZ = 10; kF8E4M3FNUZ = 11; kInt64 = 12; + kF8E4M3 = 13; } // Describes how a convolution input or output layer's data is formatted. 
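As a quick illustration (not part of the patch): the TSL-level plumbing above — the ml_dtypes pin, the exported GetFloat8E4M3* pybind symbols, the new tsl::float8_e4m3 alias, and the kF8E4M3 entry in dnn.proto — introduces an IEEE-style E4M3 type (4 exponent bits, 3 mantissa bits, bias 7, with +/-inf and NaN). The sketch below is my own reading of the properties the tests later in this change rely on; the numeric values are assumptions based on that layout, not something this diff asserts.

// Illustrative sketch only; assumes the ml_dtypes bump above is in place.
#include <iostream>
#include <limits>
#include "tsl/platform/ml_dtypes.h"

int main() {
  using F8 = tsl::float8_e4m3;
  F8 x(3.3f);  // rounds to the nearest representable value, 3.25
  std::cout << static_cast<float>(x) << "\n";
  // Largest finite value is 240 (the e4m3fn variant, by contrast, reaches 448).
  std::cout << static_cast<float>(std::numeric_limits<F8>::max()) << "\n";
  // Smallest normal is 2**-6 and smallest denormal is 2**-9, which matches the
  // distances exercised in fp_util_test.cc further down in this change.
  std::cout << static_cast<float>(std::numeric_limits<F8>::min()) << "\n";
  std::cout << static_cast<float>(std::numeric_limits<F8>::denorm_min()) << "\n";
  // Unlike the *fn/*fnuz variants, this format has infinities.
  std::cout << std::numeric_limits<F8>::has_infinity << "\n";
  return 0;
}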
diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 4d0fbf3732ff9a..347f2efb12beaf 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } +TEST(Array2dTest, LinspaceF8E4M3) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + TEST(Array2dTest, LinspaceF8E4M3Fn) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); diff --git a/xla/client/lib/math.cc b/xla/client/lib/math.cc index e7b362ff221662..8ec6b3443a2b27 100644 --- a/xla/client/lib/math.cc +++ b/xla/client/lib/math.cc @@ -171,6 +171,7 @@ XlaOp IsNegZero(XlaOp operand) { return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: @@ -315,12 +316,14 @@ XlaOp Erfc(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32( - x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, - [](XlaOp x) { - return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), - ScalarLike(x, 1) - ErfImpl32Cephes(x)); - }); + return DoWithUpcastToF32(x, + {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + [](XlaOp x) { + return Select( + Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), + ScalarLike(x, 1) - ErfImpl32Cephes(x)); + }); }); } @@ -492,9 +495,10 @@ XlaOp ErfInv(XlaOp x) { if (shape.element_type() == F64) { return ErfInv64(x); } - return DoWithUpcastToF32( - x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, - [](XlaOp x) { return ErfInv32(x); }); + return DoWithUpcastToF32(x, + {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + [](XlaOp x) { return ErfInv32(x); }); }); } @@ -622,10 +626,10 @@ XlaOp Lgamma(XlaOp input) { // F16 and BF16 don't provide sufficient precision for intermediate results // here (although it's better than you might expect!), so do the // computations in F32. 
- return DoWithUpcastToF32( - input, - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, - do_it); + return DoWithUpcastToF32(input, + {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + do_it); }); } @@ -720,10 +724,10 @@ XlaOp Digamma(XlaOp input) { auto& b = *input.builder(); return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input)); - return DoWithUpcastToF32( - input, - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, - do_it); + return DoWithUpcastToF32(input, + {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + do_it); }); } @@ -978,8 +982,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1031,8 +1035,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index 5e7e082385aef5..280046dddecbee 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -130,6 +130,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "TOKEN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; + case XLA_FFI_DataType_F8E4M3: + return os << "F8E4M3"; case XLA_FFI_DataType_F8E4M3FN: return os << "F8E4M3FN"; case XLA_FFI_DataType_F8E4M3B11FNUZ: diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index 39fdfe66d8ed59..da6446456c46c1 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -194,6 +194,7 @@ typedef enum { XLA_FFI_DataType_C128 = 18, XLA_FFI_DataType_TOKEN = 17, XLA_FFI_DataType_F8E5M2 = 19, + XLA_FFI_DataType_F8E4M3 = 28, XLA_FFI_DataType_F8E4M3FN = 20, XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index 6990090a4cda8e..902de2b12d048d 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -70,6 +70,7 @@ enum class DataType : uint8_t { C128 = XLA_FFI_DataType_C128, TOKEN = XLA_FFI_DataType_TOKEN, F8E5M2 = XLA_FFI_DataType_F8E5M2, + F8E4M3 = XLA_FFI_DataType_F8E4M3, F8E4M3FN = XLA_FFI_DataType_F8E4M3FN, F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ, F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, @@ -95,6 +96,7 @@ inline constexpr DataType C64 = DataType::C64; inline constexpr DataType C128 = DataType::C128; inline constexpr DataType TOKEN = DataType::TOKEN; inline constexpr DataType F8E5M2 = DataType::F8E5M2; +inline constexpr DataType F8E4M3 = DataType::F8E4M3; inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN; inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; @@ -114,6 +116,7 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::S8: case DataType::U8: case DataType::F8E5M2: + case DataType::F8E4M3: case DataType::F8E4M3FN: case DataType::F8E4M3B11FNUZ: 
case DataType::F8E5M2FNUZ: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 27dcff09504d03..2b6cf9573919cc 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -126,6 +126,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); + EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ), encoded(DataType::F8E4M3B11FNUZ)); @@ -175,6 +176,8 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), + ByteWidth(DataType::F8E4M3)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN), ByteWidth(DataType::F8E4M3FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ), diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index f77ddb8950c7f5..971ff7e0266a9f 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -256,6 +256,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C128: case PrimitiveType::TOKEN: case PrimitiveType::F8E5M2: + case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: case PrimitiveType::F8E4M3B11FNUZ: case PrimitiveType::F8E5M2FNUZ: diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index e3c7a5411e0bc4..163badc2f94413 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -110,6 +110,59 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest, 0x1.fffffffffffffp-127, 0x1.aaaaaaaaaaaaap-127)); +TEST(FPDistanceTest, F8E4M3Distance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e4m3(8.0), + tsl::float8_e4m3(8.0)), + 0); + + // a & b have the same exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e4m3(8.0), + tsl::float8_e4m3(13)), + 5); + + // a & b have different exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e4m3(8.0), + tsl::float8_e4m3(6.0)), + 4); + + // 1 from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), + tsl::float8_e4m3(0)), + 1); + + // 1 from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + tsl::float8_e4m3(0)), + 1); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), + 2); + + // 1 non denorm from 0 in the positive direction + EXPECT_EQ( + CalculateDistanceInFloats( + std::numeric_limits::min(), tsl::float8_e4m3(0)), + 8); + + // 1 non denorm from 0 in the negative direction + EXPECT_EQ( + CalculateDistanceInFloats( + -std::numeric_limits::min(), tsl::float8_e4m3(0)), + 8); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), + 16); +} + TEST(FPDistanceTest, F8E4M3FNDistance) { // a & b are equal EXPECT_EQ(CalculateDistanceInFloats( diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index b107b75447dbae..6ab2e2906d94ef 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -520,51 +520,51 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { absl::Status 
HandlePower(const HloInstruction* power) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - // Case 0: 1^x = 1 and x^0 = 1, regardless of X, see - // Branch Cuts for Complex Elementary Functions or Much Ado About - // Nothing's Sign Bit, W. Kahan, Section 10. - if (lhs_el == ElementwiseT(1) || rhs_el == ElementwiseT(0)) { - return static_cast(1); - } - // Case 1: - // 1. inf^(a + 0i) = inf, if a > 0. - // 2. inf^(a + 0i) = 0, if a < 0. - if constexpr (is_complex_v) { - auto is_positive_infinity = [](ElementwiseT c) { - return c.imag() == 0 && c.real() > 0 && std::isinf(c.real()); - }; - auto is_positive_real = [](ElementwiseT c) { - return c.real() > 0 && c.imag() == 0; - }; - auto is_negative_real = [](ElementwiseT c) { - return c.real() < 0 && c.imag() == 0; - }; - if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el)) { - return static_cast(lhs_el); - } - if (is_positive_infinity(lhs_el) && is_negative_real(rhs_el)) { - return static_cast(0); - } - } - // Case 2: - // Fallback to pow. - if constexpr (std::is_same_v) { - return lhs_el || !rhs_el; - } else if constexpr (std::is_integral_v) { - if constexpr (std::is_signed_v) { - if (rhs_el < static_cast(0)) { + ElementWiseBinaryOp( + power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + // Case 0: 1^x = 1 and x^0 = 1, regardless of X, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + if (lhs_el == ElementwiseT(1) || rhs_el == ElementwiseT(0)) { + return static_cast(1); + } + // Case 1: + // 1. inf^(a + 0i) = inf, if a > 0. + // 2. inf^(a + 0i) = 0, if a < 0. + if constexpr (is_complex_v) { + auto is_positive_infinity = [](ElementwiseT c) { + return c.imag() == 0 && c.real() > 0 && std::isinf(c.real()); + }; + auto is_positive_real = [](ElementwiseT c) { + return c.real() > 0 && c.imag() == 0; + }; + auto is_negative_real = [](ElementwiseT c) { + return c.real() < 0 && c.imag() == 0; + }; + if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el)) { + return static_cast(lhs_el); + } + if (is_positive_infinity(lhs_el) && is_negative_real(rhs_el)) { + return static_cast(0); + } + } + // Case 2: + // Fallback to pow. + if constexpr (std::is_same_v) { + return lhs_el || !rhs_el; + } else if constexpr (std::is_integral_v) { + if constexpr (std::is_signed_v) { + if (rhs_el < static_cast(0)) { + return static_cast( + lhs_el == static_cast(1) ? 1 : 0); + } + } return static_cast( - lhs_el == static_cast(1) ? 
1 : 0); + IPow>(lhs_el, rhs_el)); + } else { + return static_cast(std::pow(lhs_el, rhs_el)); } - } - return static_cast( - IPow>(lhs_el, rhs_el)); - } else { - return static_cast(std::pow(lhs_el, rhs_el)); - } - })); + })); return absl::OkStatus(); } @@ -1743,6 +1743,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index 7c97c210aa36a5..b0edfbe630665b 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -19,6 +19,7 @@ limitations under the License. namespace xla { template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; diff --git a/xla/literal.cc b/xla/literal.cc index c1026718435087..10e1ffcc6c54d6 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -91,12 +91,12 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || - !proto.f8e5m2s().empty() || !proto.f8e4m3fns().empty() || - !proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() || - !proto.f8e4m3fnuzs().empty() || !proto.f16s().empty() || - !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || - proto.c64s_size() || proto.c128s_size() || proto.preds_size() || - proto.tuple_literals_size(); + !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || + !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || + !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || + !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || + proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || + proto.preds_size() || proto.tuple_literals_size(); } // Lazy getter for the interned scalar shape in static storage. We reuse this @@ -2258,6 +2258,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E4M3: + *proto->mutable_f8e4m3s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F8E4M3FN: *proto->mutable_f8e4m3fns() = std::string( reinterpret_cast(data().data()), @@ -2436,6 +2441,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E4M3: { + const std::string& s(proto.f8e4m3s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e4m3) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F8E4M3FN: { const std::string& s(proto.f8e4m3fns()); TF_RET_CHECK(data().size() * diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 893820780276fe..309f8c8f9ce602 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -25,6 +25,63 @@ limitations under the License. 
namespace xla { namespace { +TEST(LiteralComparisonTest, F8E4M3CompareNear_Equal) { + auto actual = LiteralUtil::CreateR0(tsl::float8_e4m3(8.0)); + auto expected = + LiteralUtil::CreateR0(tsl::float8_e4m3(8.0)); + TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0), + /*detailed_message=*/false, + /*miscompare_callback=*/nullptr)); +} + +TEST(LiteralComparisonTest, F8E4M3CompareNear_NotEqual_1ulp) { + auto actual = LiteralUtil::CreateR0(tsl::float8_e4m3(8.0)); + auto expected = + LiteralUtil::CreateR0(tsl::float8_e4m3(9.0)); + auto error_spec = ErrorSpec(0.0, 0.0); + EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + /*detailed_message=*/false, + /*miscompare_callback=*/nullptr)); + error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3; + error_spec.low_precision_fp_error_spec.within_n_values = 1; + EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + /*detailed_message=*/false, + /*miscompare_callback=*/nullptr)); +} + +TEST(LiteralComparisonTest, F8E4M3CompareNear_NotEqual_4ulps) { + auto actual = LiteralUtil::CreateR0(tsl::float8_e4m3(8.0)); + auto expected = + LiteralUtil::CreateR0(tsl::float8_e4m3(12.0)); + auto error_spec = ErrorSpec(0.0, 0.0); + error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3; + error_spec.low_precision_fp_error_spec.within_n_values = 1; + EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + /*detailed_message=*/false, + /*miscompare_callback=*/nullptr)); + error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3; + error_spec.low_precision_fp_error_spec.within_n_values = 4; + EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + /*detailed_message=*/false, + /*miscompare_callback=*/nullptr)); +} + +TEST(LiteralComparisonTest, FloatUsingF8E4M3CompareNear_NotEqual_4ulps) { + auto actual = LiteralUtil::CreateR0(8.0); + auto expected = LiteralUtil::CreateR0(12.1); + auto error_spec = ErrorSpec(0.0, 0.0); + error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3; + error_spec.low_precision_fp_error_spec.within_n_values = 1; + EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + /*detailed_message=*/false, + /*miscompare_callback=*/nullptr)); + error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3; + error_spec.low_precision_fp_error_spec.within_n_values = 4; + EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + /*detailed_message=*/false, + /*miscompare_callback=*/nullptr)); +} + TEST(LiteralComparisonTest, F8E4M3FNCompareNear_Equal) { auto actual = LiteralUtil::CreateR0(tsl::float8_e4m3fn(8.0)); diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 42b4340d2ddf82..9449d2a1091e03 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -173,8 +173,12 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { EXPECT_EQ("f8e5m2[] 3", f8e5m2_lit_truncated.ToString()); auto f8e4m3_lit = + LiteralUtil::CreateR0(tsl::float8_e4m3(0.5)); + EXPECT_EQ("f8e4m3[] 0.5", f8e4m3_lit.ToString()); + + auto f8e4m3fn_lit = LiteralUtil::CreateR0(tsl::float8_e4m3fn(0.5)); - EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3_lit.ToString()); + EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3fn_lit.ToString()); auto f8e4m3b11fnuz_lit = LiteralUtil::CreateR0( tsl::float8_e4m3b11fnuz(0.5)); @@ -644,15 +648,19 @@ TEST_F(LiteralUtilTest, IsAll) { // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(9)); - tsl::float8_e4m3fn r16(9); // Exactly 
representable in e4m3 + tsl::float8_e4m3 e4m3(9); // Exactly representable in e4m3 + EXPECT_FALSE(LiteralUtil::CreateR1({e4m3}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({e4m3}).IsAll(9)); + + tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3fn EXPECT_FALSE(LiteralUtil::CreateR1({r16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({r16}).IsAll(9)); - tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3 + tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3b11fnuz EXPECT_FALSE(LiteralUtil::CreateR1({s16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({s16}).IsAll(9)); - tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3 + tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3fnuz EXPECT_FALSE(LiteralUtil::CreateR1({t16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({t16}).IsAll(9)); @@ -1226,6 +1234,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F8e5m2) { } TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3) { + Literal output(ShapeUtil::MakeShape(F8E4M3, {3})); + tsl::float8_e4m3 x(0.5f); + output.PopulateWithValue(x); + auto expected = LiteralUtil::CreateR1({x, x, x}); + EXPECT_EQ(output, expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3fn) { Literal output(ShapeUtil::MakeShape(F8E4M3FN, {3})); tsl::float8_e4m3fn x(0.5f); output.PopulateWithValue(x); @@ -1747,9 +1763,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { using e5 = tsl::float8_e5m2; auto f8e5m2 = LiteralUtil::CreateR2WithLayout( {{e5{0.}, e5{1.}}, {e5{2.}, e5{3.}}}, layout_r2_dim0major_); - using e4 = tsl::float8_e4m3fn; + using e4 = tsl::float8_e4m3; auto f8e4m3 = LiteralUtil::CreateR2WithLayout( {{e4{0.}, e4{1.}}, {e4{2.}, e4{3.}}}, layout_r2_dim0major_); + using e4fn = tsl::float8_e4m3fn; + auto f8e4m3fn = LiteralUtil::CreateR2WithLayout( + {{e4fn{0.}, e4fn{1.}}, {e4fn{2.}, e4fn{3.}}}, layout_r2_dim0major_); using b11 = tsl::float8_e4m3b11fnuz; auto f8e4m3b11 = LiteralUtil::CreateR2WithLayout( {{b11{0.}, b11{1.}}, {b11{2.}, b11{3.}}}, layout_r2_dim0major_); @@ -1770,15 +1789,27 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { conv = f8e4m3.Convert(F8E5M2).value(); EXPECT_EQ(conv, f8e5m2); - conv = s8.Convert(F8E4M3FN).value(); + conv = s8.Convert(F8E4M3).value(); EXPECT_EQ(conv, f8e4m3); - conv = f32.Convert(F8E4M3FN).value(); + conv = f32.Convert(F8E4M3).value(); EXPECT_EQ(conv, f8e4m3); - conv = f8e5m2.Convert(F8E4M3FN).value(); + conv = f8e5m2.Convert(F8E4M3).value(); EXPECT_EQ(conv, f8e4m3); + conv = f8e4m3fn.Convert(F8E5M2).value(); + EXPECT_EQ(conv, f8e5m2); + + conv = s8.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3fn); + + conv = f32.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3fn); + + conv = f8e5m2.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3fn); + conv = f8e5m2.Convert(S8).value(); EXPECT_EQ(conv, s8); @@ -1797,6 +1828,15 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { conv = f8e4m3.Convert(C128).value(); EXPECT_EQ(conv, c128); + conv = f8e4m3fn.Convert(S8).value(); + EXPECT_EQ(conv, s8); + + conv = f8e4m3fn.Convert(F32).value(); + EXPECT_EQ(conv, f32); + + conv = f8e4m3fn.Convert(C128).value(); + EXPECT_EQ(conv, c128); + conv = f8e4m3b11.Convert(S8).value(); EXPECT_EQ(conv, s8); @@ -2254,9 +2294,12 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e5 = tsl::float8_e5m2; auto vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); - using e4 = tsl::float8_e4m3fn; + using e4 = tsl::float8_e4m3; auto vector_f8e4m3 = LiteralUtil::CreateR1({e4{10.0}, e4{20.0}, e4{-32.0}}); + using e4fn = 
tsl::float8_e4m3fn; + auto vector_f8e4m3fn = + LiteralUtil::CreateR1({e4fn{10.0}, e4fn{20.0}, e4fn{-32.0}}); using b11 = tsl::float8_e4m3b11fnuz; auto vector_f8e4m3b11 = LiteralUtil::CreateR1({b11{10.0}, b11{20.0}, b11{-30.0}}); @@ -2288,6 +2331,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); + EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); @@ -2575,6 +2619,14 @@ TEST_F(LiteralUtilTest, IsEqualAt) { tsl::float8_e4m3fnuz{val_double}); EXPECT_TRUE(c6.IsEqualAt({}, val_double)); EXPECT_TRUE(c6.IsEqualAt({}, val_integral)); + Literal c8 = + LiteralUtil::CreateR0(tsl::float8_e4m3{val_double}); + EXPECT_TRUE(c8.IsEqualAt({}, val_double)); + EXPECT_TRUE(c8.IsEqualAt({}, val_integral)); + Literal c9 = + LiteralUtil::CreateR0(tsl::float8_e4m3fn{val_double}); + EXPECT_TRUE(c9.IsEqualAt({}, val_double)); + EXPECT_TRUE(c9.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2900,10 +2952,10 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, - F8E5M2FNUZ, F8E4M3FNUZ, C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/literal_util.cc b/xla/literal_util.cc index 745194cdc24b39..9f410b26716a9d 100644 --- a/xla/literal_util.cc +++ b/xla/literal_util.cc @@ -239,6 +239,16 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, return ConvertType(bf16_literal); } +/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + +/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FN( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + /* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FNUZ( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); diff --git a/xla/literal_util.h b/xla/literal_util.h index a19ed6fb1e529e..1861668578bac1 100644 --- a/xla/literal_util.h +++ b/xla/literal_util.h @@ -241,6 +241,8 @@ class LiteralUtil { // recursively converts its elements. 
static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal); + static Literal ConvertF32ToF8E4M3(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E4M3FN(const LiteralSlice& f32_literal); static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal); static Literal ConvertF32ToF8E5M2FNUZ(const LiteralSlice& f32_literal); static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); diff --git a/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc b/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc index 7b3769f09ebe13..e7b613f47dcc88 100644 --- a/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc +++ b/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc @@ -280,6 +280,8 @@ absl::StatusOr LiteralToValue(const xla::Literal& literal) { return {{ArrayLiteralToTensor(literal)}}; case xla::F8E5M2: return absl::UnimplementedError("F8E5M2 not implemented"); + case xla::F8E4M3: + return absl::UnimplementedError("F8E4M3 not implemented"); case xla::F8E4M3FN: return absl::UnimplementedError("F8E4M3FN not implemented"); case xla::F8E4M3B11FNUZ: diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index 59b19c34611412..7e293ce536f487 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -34,6 +34,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getI1Type(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); + case xla::PrimitiveType::F8E4M3: + return b.getFloat8E4M3Type(); case xla::PrimitiveType::F8E4M3FN: return b.getFloat8E4M3FNType(); case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -76,6 +78,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; + } else if (type.isFloat8E4M3()) { + return xla::PrimitiveType::F8E4M3; } else if (type.isFloat8E4M3FN()) { return xla::PrimitiveType::F8E4M3FN; } else if (type.isFloat8E4M3B11FNUZ()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index 6c19098574dec5..d8c8e1a7245e9a 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -102,6 +102,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(std::vector( {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, + {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, {F8E4M3B11FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3B11FNUZType(); }}, diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 90965f06086831..4edcd8ed2a0ae8 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1814,6 +1814,13 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: 
"stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 65594f55fd979d..651054bc899892 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6832,6 +6832,13 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e4m3fn(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 0f2e1b108a710f..8897c4274e8506 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1787,6 +1787,13 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 96b71755b60e8a..db28cedb1b8d96 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -630,6 +630,7 @@ typedef enum { // Truncated 8 bit floating-point formats. PJRT_Buffer_Type_F8E5M2, + PJRT_Buffer_Type_F8E4M3, PJRT_Buffer_Type_F8E4M3FN, PJRT_Buffer_Type_F8E4M3B11FNUZ, PJRT_Buffer_Type_F8E5M2FNUZ, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index b9508cf24950b4..1ea432b7c200e5 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -295,6 +295,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; + case xla::PrimitiveType::F8E4M3: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3; case xla::PrimitiveType::F8E4M3FN: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN; case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -358,6 +360,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::C128; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: + return xla::PrimitiveType::F8E4M3; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN: return xla::PrimitiveType::F8E4M3FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3B11FNUZ: diff --git a/xla/primitive_util.h b/xla/primitive_util.h index 8fbeedbff94dad..5feb48e5ec9b95 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -180,6 +180,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E4M3; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3FN; @@ -309,6 +314,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e5m2; }; +template <> +struct PrimitiveTypeToNative { + using type = 
tsl::float8_e4m3; +}; + template <> struct PrimitiveTypeToNative { using type = tsl::float8_e4m3fn; @@ -362,8 +372,8 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { } constexpr bool IsF8Type(PrimitiveType type) { - return type == F8E5M2 || type == F8E4M3FN || type == F8E4M3B11FNUZ || - type == F8E5M2FNUZ || type == F8E4M3FNUZ; + return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || + type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ; } constexpr bool IsFloatingPointType(PrimitiveType type) { @@ -428,6 +438,9 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { + case F8E4M3: + return std::forward(f)( + PrimitiveTypeConstant()); case F8E4M3FN: return std::forward(f)( PrimitiveTypeConstant()); diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index e8c9dc77087062..eed2ee0f2159a2 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -76,6 +76,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][BF16] = true; expecteds[PRED][C128] = true; expecteds[PRED][F8E5M2] = true; + expecteds[PRED][F8E4M3] = true; expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = true; expecteds[PRED][F8E5M2FNUZ] = true; @@ -100,6 +101,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][BF16] = true; expecteds[S2][C128] = true; expecteds[S2][F8E5M2] = true; + expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; expecteds[S2][F8E4M3B11FNUZ] = true; expecteds[S2][F8E5M2FNUZ] = true; @@ -124,6 +126,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][BF16] = true; expecteds[S4][C128] = true; expecteds[S4][F8E5M2] = true; + expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; expecteds[S4][F8E4M3B11FNUZ] = true; expecteds[S4][F8E5M2FNUZ] = true; @@ -148,6 +151,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][BF16] = true; expecteds[S8][C128] = true; expecteds[S8][F8E5M2] = false; + expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; expecteds[S8][F8E4M3B11FNUZ] = false; expecteds[S8][F8E5M2FNUZ] = false; @@ -172,6 +176,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][BF16] = false; expecteds[S16][C128] = true; expecteds[S16][F8E5M2] = false; + expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; expecteds[S16][F8E4M3B11FNUZ] = false; expecteds[S16][F8E5M2FNUZ] = false; @@ -196,6 +201,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][BF16] = false; expecteds[S32][C128] = true; expecteds[S32][F8E5M2] = false; + expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; expecteds[S32][F8E4M3B11FNUZ] = false; expecteds[S32][F8E5M2FNUZ] = false; @@ -220,6 +226,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][BF16] = false; expecteds[S64][C128] = false; expecteds[S64][F8E5M2] = false; + expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; expecteds[S64][F8E4M3B11FNUZ] = false; expecteds[S64][F8E5M2FNUZ] = false; @@ -246,6 +253,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][BF16] = true; expecteds[U2][C128] = true; expecteds[U2][F8E5M2] = true; + expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; expecteds[U2][F8E4M3B11FNUZ] = true; expecteds[U2][F8E5M2FNUZ] = true; @@ -272,6 +280,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][BF16] = true; expecteds[U4][C128] = true; expecteds[U4][F8E5M2] = false; + 
expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; expecteds[U4][F8E4M3B11FNUZ] = true; expecteds[U4][F8E5M2FNUZ] = false; @@ -298,6 +307,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][BF16] = true; expecteds[U8][C128] = true; expecteds[U8][F8E5M2] = false; + expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; expecteds[U8][F8E4M3B11FNUZ] = false; expecteds[U8][F8E5M2FNUZ] = false; @@ -322,6 +332,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][BF16] = false; expecteds[U16][C128] = true; expecteds[U16][F8E5M2] = false; + expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; expecteds[U16][F8E4M3B11FNUZ] = false; expecteds[U16][F8E5M2FNUZ] = false; @@ -346,6 +357,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][BF16] = false; expecteds[U32][C128] = true; expecteds[U32][F8E5M2] = false; + expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; expecteds[U32][F8E4M3B11FNUZ] = false; expecteds[U32][F8E5M2FNUZ] = false; @@ -370,6 +382,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][BF16] = false; expecteds[U64][C128] = false; expecteds[U64][F8E5M2] = false; + expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; expecteds[U64][F8E4M3B11FNUZ] = false; expecteds[U64][F8E5M2FNUZ] = false; @@ -394,6 +407,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][BF16] = false; expecteds[F16][C128] = true; expecteds[F16][F8E5M2] = false; + expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; expecteds[F16][F8E4M3B11FNUZ] = false; expecteds[F16][F8E5M2FNUZ] = false; @@ -418,6 +432,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][BF16] = false; expecteds[F32][C128] = true; expecteds[F32][F8E5M2] = false; + expecteds[F32][F8E4M3] = false; expecteds[F32][F8E4M3FN] = false; expecteds[F32][F8E4M3B11FNUZ] = false; expecteds[F32][F8E5M2FNUZ] = false; @@ -442,6 +457,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][BF16] = false; expecteds[F64][C128] = true; expecteds[F64][F8E5M2] = false; + expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; expecteds[F64][F8E4M3B11FNUZ] = false; expecteds[F64][F8E5M2FNUZ] = false; @@ -466,6 +482,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][BF16] = false; expecteds[C64][C128] = true; expecteds[C64][F8E5M2] = false; + expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; expecteds[C64][F8E4M3B11FNUZ] = false; expecteds[C64][F8E5M2FNUZ] = false; @@ -490,6 +507,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; expecteds[BF16][F8E5M2] = false; + expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; expecteds[BF16][F8E4M3B11FNUZ] = false; expecteds[BF16][F8E5M2FNUZ] = false; @@ -514,6 +532,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][BF16] = false; expecteds[C128][C128] = true; expecteds[C128][F8E5M2] = false; + expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; expecteds[C128][F8E4M3B11FNUZ] = false; expecteds[C128][F8E5M2FNUZ] = false; @@ -538,10 +557,36 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; expecteds[F8E5M2][F8E5M2] = true; + expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; expecteds[F8E5M2][F8E4M3B11FNUZ] = false; expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; + expecteds[F8E4M3][PRED] = false; + 
expecteds[F8E4M3][S2] = false; + expecteds[F8E4M3][S4] = false; + expecteds[F8E4M3][S8] = false; + expecteds[F8E4M3][S16] = false; + expecteds[F8E4M3][S32] = false; + expecteds[F8E4M3][S64] = false; + expecteds[F8E4M3][U2] = false; + expecteds[F8E4M3][U4] = false; + expecteds[F8E4M3][U8] = false; + expecteds[F8E4M3][U16] = false; + expecteds[F8E4M3][U32] = false; + expecteds[F8E4M3][U64] = false; + expecteds[F8E4M3][F16] = true; + expecteds[F8E4M3][F32] = true; + expecteds[F8E4M3][F64] = true; + expecteds[F8E4M3][C64] = true; + expecteds[F8E4M3][BF16] = true; + expecteds[F8E4M3][C128] = true; + expecteds[F8E4M3][F8E5M2] = false; + expecteds[F8E4M3][F8E5M2FNUZ] = false; + expecteds[F8E4M3][F8E4M3] = true; + expecteds[F8E4M3][F8E4M3FN] = false; + expecteds[F8E4M3][F8E4M3FNUZ] = false; + expecteds[F8E4M3][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S2] = false; expecteds[F8E4M3FN][S4] = false; @@ -562,7 +607,10 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; expecteds[F8E4M3FN][F8E5M2] = false; + expecteds[F8E4M3FN][F8E5M2FNUZ] = false; + expecteds[F8E4M3FN][F8E4M3] = false; expecteds[F8E4M3FN][F8E4M3FN] = true; + expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S2] = false; @@ -584,12 +632,11 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; + expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; expecteds[F8E4M3B11FNUZ][F8E4M3B11FNUZ] = true; expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S2] = false; expecteds[F8E5M2FNUZ][S4] = false; @@ -610,6 +657,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; expecteds[F8E5M2FNUZ][F8E5M2] = false; + expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; expecteds[F8E5M2FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; @@ -634,6 +682,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; expecteds[F8E4M3FNUZ][F8E5M2] = false; + expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; expecteds[F8E4M3FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index 1de5702b6cc8df..95c49455b1b94b 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -133,6 +133,7 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(BF16); CASE(C64); CASE(C128); + CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -175,6 +176,7 @@ DTypeProto DType::ToProto() const { CASE(BF16); CASE(C64); CASE(C128); + CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index 06a92b67f863c8..525f5d261f178d 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -78,13 +78,14 @@ class DType { // dtype will have empty dimensions. 
kToken = 17, + kF8E4M3 = 28, kF8E4M3FN = 20, kF8E4M3B11FNUZ = 23, kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, - // Next = 26 + // Next = 29 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index eadfd42a3550cd..e96fa62537876f 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -60,6 +60,7 @@ message DTypeProto { // dtype will have empty dimensions. KIND_TOKEN = 17; + KIND_F8E4M3 = 28; KIND_F8E4M3FN = 20; KIND_F8E4M3B11FNUZ = 23; KIND_F8E4M3FNUZ = 25; diff --git a/xla/python/pjrt_ifrt/pjrt_array.cc b/xla/python/pjrt_ifrt/pjrt_array.cc index 751b00c9b37620..01700a4fc6dc22 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/xla/python/pjrt_ifrt/pjrt_array.cc @@ -148,6 +148,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kU16, xla::PrimitiveType::U16); CASE(DType::kU32, xla::PrimitiveType::U32); CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ); CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); @@ -184,6 +185,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U16: case xla::PrimitiveType::U32: case xla::PrimitiveType::U64: + case xla::PrimitiveType::F8E4M3: case xla::PrimitiveType::F8E4M3FN: case xla::PrimitiveType::F8E4M3B11FNUZ: case xla::PrimitiveType::F8E4M3FNUZ: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index c5d9051f1c603f..aeab43f5992dcf 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -34,7 +34,7 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/complex.h" // IWYU pragma: keep +#include "nanobind/stl/complex.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" @@ -184,6 +184,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E4M3FN; @@ -393,6 +396,7 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = @@ -582,6 +586,7 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index f3b8db6fdae018..c628edd51fac4f 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -31,9 +31,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" -#include "nanobind/ndarray.h" // IWYU pragma: keep -#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep -#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/ndarray.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "xla/layout.h" #include "xla/literal.h" @@ -59,6 +59,7 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; + nb_dtype float8_e4m3; nb_dtype float8_e4m3fn; nb_dtype float8_e4m3b11fnuz; nb_dtype float8_e4m3fnuz; @@ -75,6 +76,7 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); + dtypes->float8_e4m3 = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3")); dtypes->float8_e4m3fn = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fn")); dtypes->float8_e5m2 = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2")); @@ -140,6 +142,7 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); + map->emplace(custom_dtypes.float8_e4m3, F8E4M3); map->emplace(custom_dtypes.float8_e4m3fn, F8E4M3FN); map->emplace(custom_dtypes.float8_e4m3b11fnuz, F8E4M3B11FNUZ); map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); @@ -204,6 +207,8 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); + case F8E4M3: + return custom_dtypes.float8_e4m3; case F8E4M3FN: return custom_dtypes.float8_e4m3fn; case 
F8E4M3B11FNUZ: @@ -284,6 +289,8 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); + case ifrt::DType::kF8E4M3: + return custom_dtypes.float8_e4m3; case ifrt::DType::kF8E4M3FN: return custom_dtypes.float8_e4m3fn; case ifrt::DType::kF8E4M3B11FNUZ: @@ -347,6 +354,7 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + dtypes->np_float8_e4m3 = nb::object(ml_dtypes.attr("float8_e4m3")); dtypes->np_float8_e4m3fn = nb::object(ml_dtypes.attr("float8_e4m3fn")); dtypes->np_float8_e4m3b11fnuz = nb::object(ml_dtypes.attr("float8_e4m3b11fnuz")); diff --git a/xla/python/types.h b/xla/python/types.h index ed7ca847b1a7f7..1e849fe85eeb70 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -79,6 +79,7 @@ struct NumpyScalarTypes { nanobind::object np_uint32; nanobind::object np_uint64; nanobind::object np_bfloat16; + nanobind::object np_float8_e4m3; nanobind::object np_float8_e4m3fn; nanobind::object np_float8_e4m3b11fnuz; nanobind::object np_float8_e4m3fnuz; @@ -128,7 +129,6 @@ nanobind::tuple SpanToNbTuple(absl::Span xs) { // references to the objects. nanobind::tuple MutableSpanToNbTuple(absl::Span xs); - template std::vector IterableToVector(const nanobind::iterable& iterable) { std::vector output; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 2136e981507f10..dc277462b7e7ac 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -36,16 +36,16 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/nb_defs.h" -#include "nanobind/stl/function.h" // IWYU pragma: keep -#include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "nanobind/stl/set.h" // IWYU pragma: keep -#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep -#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep -#include "nanobind/stl/variant.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" @@ -71,8 +71,8 @@ limitations under the License. #elif defined(__APPLE__) #include "gloo/transport/uv/device.h" #include "xla/pjrt/cpu/gloo_collectives.h" // NOLINT -#include "xla/pjrt/cpu/gloo_kv_store.h" // NOLINT -#endif // defined(__linux__) +#include "xla/pjrt/cpu/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) #if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) #include "xla/pjrt/cpu/mpi_collectives.h" @@ -93,7 +93,7 @@ limitations under the License. 
#include "xla/python/logging.h" // IWYU pragma: keep #include "xla/python/mlir.h" #include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep -#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_class_ptr.h" #include "xla/python/ops.h" #include "xla/python/outfeed_receiver_py.h" @@ -197,6 +197,7 @@ NB_MODULE(xla_extension, m_nb) { .value("U32", U32) .value("U64", U64) .value("F16", F16) + .value("F8E4M3", F8E4M3) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index f49cfefc85c321..fe673593b442df 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -279,6 +279,7 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType = _xla.PrimitiveType bfloat16 = ml_dtypes.bfloat16 +float8_e4m3 = ml_dtypes.float8_e4m3 float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -297,6 +298,7 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), + PrimitiveType.F8E4M3: np.dtype(float8_e4m3), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index 8731080c99b52a..ac80e8fb0918dd 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -59,6 +59,7 @@ _version: int mlir_api_version: int bfloat16: type[numpy.generic] +float8_e4m3: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index d4406155eadd42..cca05c3a37844d 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -54,6 +54,7 @@ xla_client._xla.jax_jit.global_state().enable_memories = False bfloat16 = xla_client.bfloat16 +float8_e4m3 = xla_client.float8_e4m3 float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -138,7 +139,7 @@ def TestFactory(xla_backend, # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. 
- standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3, float8_e5m2] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 93088d1ae9c06f..65571e53049a6a 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -73,6 +73,7 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType + F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType F8E4M3B11FNUZ: PrimitiveType F8E4M3FNUZ: PrimitiveType diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index d1fc6de1e13ca0..4a189527cd4c12 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -594,6 +594,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( #endif FloatSupport f8e5m2_support(F8E5M2, F16); pipeline.AddPass(&f8e5m2_support); + FloatSupport f8e4m3_support(F8E4M3, F16); + pipeline.AddPass(&f8e4m3_support); FloatSupport f8e4m3fn_support(F8E4M3FN, F16); pipeline.AddPass(&f8e4m3fn_support); FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); @@ -1535,8 +1537,9 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { }); for (const auto& kernel : symbols.kernels) { - TraceMe trace( - [&] { return TraceMeEncode("Kernel", {{"name", kernel.name}}); }); + TraceMe trace([&] { + return TraceMeEncode("Kernel", {{"name", kernel.name}}); + }); if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) { return Internal("Failed to find compiled symbol for kernel %s", kernel.name); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index c0c956a32dc0b1..a72879521e9da5 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -71,7 +71,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ + // F8E4M3B11FNUZ, F8E4M3 default: return dt::undef; } diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 9b76745911fd5d..1d81cc6a7748bc 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -220,6 +220,59 @@ absl::StatusOr EmitReducePrecisionIR( return result; } +llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, + llvm::Value* f8_bits, + llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + // F16 values that are halfway between denormal F8 values. This is used to + // determine how to round to denormal F8 values. 
+ const int halfway_points[8] = { + 0x1400, // 2**-10; halfway between [0, 2**-9] + 0x1A00, // 1.5 * 2**-9; halfway between [2**-9, 2**-8] + 0x1D00, // 1.25 * 2**-8; halfway between [2**-8, 1.5 * 2**-8] + 0x1F00, // 1.75 * 2**-8; halfway between [1.5 * 2**-8, 2**-7] + 0x2080, // 1.125 * 2**-7; halfway between [2**-7, 1.25 * 2**-7] + 0x2180, // 1.375 * 2**-7; halfway between [1.25 * 2**-7, 1.5 * 2**-7] + 0x2280, // 1.625 * 2**-7; halfway between [1.5 * 2**-7, 1.75 * 2**-7] + 0x2380, // 1.875 * 2**-7; halfway between [1.75 * 2**-7, 2**-6] + }; + + // Handle case where output is denormal. If we're rounding to a denormal + // value, ignore the current value of f8_bits and set it to the correct + // denormal value. We emit the equivalent of the following: + // + // if (f16_abs_bits <= halfway_points[0]) { + // f8_bits = 0; + // } else if (f16_abs_bits < halfway_points[1]) { + // f8_bits = 1; + // } else if (f16_abs_bits <= halfway_points[2]) { + // ... // More if-else statements. The comparisons alternate between <= + // ... // and < to handle round-to-even properly. + // } else if (f16_abs_bits < halfway_points[7]) { + // f8_bits = 7; + // } + for (int i = ABSL_ARRAYSIZE(halfway_points) - 1; i >= 0; i--) { + Value* comparison; + if (i % 2 == 0) { + comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); + } else { + comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); + } + f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); + } + return f8_bits; +} + absl::StatusOr EmitF16ToF8e5m2(llvm::Value* f16_value, llvm::IRBuilder<>* b) { TF_ASSIGN_OR_RETURN( @@ -242,6 +295,180 @@ llvm::Value* EmitF8e5m2ToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { return b->CreateBitCast(shifted, b->getHalfTy()); } +absl::StatusOr EmitF16ToF8e4m3(llvm::Value* f16_value, + llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value. + // f16_as_int = bitcast(f16_value, int) + // f16_abs_bits = f16_as_int & 0x7FFF + Value* f16_as_int = b->CreateBitCast(f16_value, i16_type); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_as_int, i16_const(0x7FFF)); + + // Get the sign. + // f8_sign = (f16_as_int & 0x8000) >> 8 + Value* f16_sign = b->CreateAnd(f16_as_int, i16_const(0x8000)); + f16_sign = b->CreateLShr(f16_sign, i16_const(8)); + Value* f8_sign = b->CreateTrunc(f16_sign, i8_type); + + // Truncate the mantissa to 3 bits. + // Denormal values are not handled properly here and are + // dealt with later in this function. + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/4, + /*dest_mantissa_bits=*/3, + /*quiet_nans=*/true, b); + CHECK(f16_reduced_statusor.ok()); // Crash OK + Value* f16_reduced = f16_reduced_statusor.value(); + f16_reduced = b->CreateBitCast(f16_reduced, i16_type); + + // Remove the sign bit. + // f16_reduced = f16_reduced & 0x7FFF + f16_reduced = b->CreateAnd(f16_reduced, i16_const(0x7FFF)); + + // Bits of the F16 representation of the smallest F8E4M3 normal value. 
2^-6 + constexpr int min_normal_value = 0x2400; + + // Round values smaller than the smallest F8 normal value up to the smallest + // F8 normal value. The case where we round to a denormal value is handled + // later. + // f16_reduced = max(f16_reduced, min_normal_value) + f16_reduced = b->CreateSelect( + b->CreateICmpULT(f16_reduced, i16_const(min_normal_value)), + i16_const(min_normal_value), f16_reduced); + + // Right shift E5 exponent's leftmost bit to convert from E5 to E4 format. + // For example, 10011 becomes 01011; similarly, 01011 becomes 00011. + // (x & 0b1001'1111'1111'1111) | ((x & 0b0100'0000'0000'0000) >> 1) + f16_reduced = + b->CreateOr(b->CreateAnd(f16_reduced, i16_const(0x9FFF)), + b->CreateLShr(b->CreateAnd(f16_reduced, i16_const(0x4000)), + i16_const(1))); + + constexpr int mantissa_bits_difference = 7; // 10 - 3 + + // Shift to convert to F8. + // f16_reduced = f16_reduced >> mantissa_bits_difference; + f16_reduced = b->CreateLShr(f16_reduced, i16_const(mantissa_bits_difference)); + + Value* f8_bits = b->CreateTrunc(f16_reduced, i8_type); + + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); + + // Set the sign bit. + // f8_bits |= f8_sign + f8_bits = b->CreateOr(f8_bits, f8_sign); + return f8_bits; +} + +llvm::Value* EmitF8e4m3ToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value. + // f8_as_int = bitcast(f8_value, int) + // f8_abs_bits = f8_as_int & 0x7F + Value* f8_as_int = b->CreateBitCast(f8_value, i8_type); + Value* f8_abs_bits = b->CreateAnd(f8_as_int, i8_const(0x7F)); + + // We assume below that the value is neither NaN nor denormal. If it is NaN or + // denormal, the output is set to NaN or zero at the end using Select + // instructions.
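The E5-to-E4 exponent narrowing above is the one non-obvious bit trick in EmitF16ToF8e4m3. A short pure-Python illustration of the same mask-and-shift; the helper name is hypothetical, and as in the emitter the input is assumed to already be rounded and within the normal F8E4M3 range:

def narrow_exponent_e5_to_e4(f16_bits: int) -> int:
    # Keep the sign and the low three exponent bits, then fold the removed
    # exponent MSB down one position: s e4 e3 e2 e1 e0 -> s 0 e4 e2 e1 e0.
    return (f16_bits & 0x9FFF) | ((f16_bits & 0x4000) >> 1)

# F16 1.0 is 0x3C00 (exponent 01111). After narrowing, the bits read 0x1C00, and
# shifting out the extra 7 mantissa bits yields 0x38, which encodes 1.0 in F8E4M3.
assert narrow_exponent_e5_to_e4(0x3C00) == 0x1C00
assert narrow_exponent_e5_to_e4(0x3C00) >> 7 == 0x38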
+ + // Get the sign: + // f16_sign = (f8_as_int & 0x80) << 8 + Value* f8_sign = b->CreateAnd(f8_as_int, i8_const(0x80)); + Value* f16_sign = b->CreateZExt(f8_sign, i16_type); + f16_sign = b->CreateShl(f16_sign, i16_const(8)); + + constexpr int exponent_bias_difference = 15 - 7; + constexpr int f16_mantissa_bits = 10; + constexpr int f8_mantissa_bits = 3; + constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; + constexpr int f8_mantissa_mask = (1 << f8_mantissa_bits) - 1; + + // Get the exponent: + // f8_exponent = (f8_as_int & 0x78) >> f8_mantissa_bits + Value* f8_exponent_bits = b->CreateAnd(f8_as_int, i8_const(0x78)); + Value* f8_exponent = + b->CreateLShr(f8_exponent_bits, i8_const(f8_mantissa_bits)); + + // Adjust the exponent by adding the difference in exponent bias: + // f16_exponent = (f8_exponent + exponent_bias_difference) + // << f16_mantissa_bits + Value* f16_exponent = + b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); + f16_exponent = b->CreateZExt(f16_exponent, i16_type); + f16_exponent = b->CreateShl(f16_exponent, i16_const(f16_mantissa_bits)); + + // Set output exponent to 11111 if input exponent is 1111 (Inf or NaN) + // 0.1111.000 is 0x78 + // 0.11111.000000000000 is 0x7C00 + Value* is_exp_1111 = b->CreateICmpEQ(f8_exponent_bits, i8_const(0x78)); + f16_exponent = b->CreateSelect(is_exp_1111, i16_const(0x7C00), f16_exponent); + + // Get the mantissa: + // f16_mantissa = (f8_mantissa & f8_mantissa_mask) + // << mantissa_bits_difference + Value* f8_mantissa = b->CreateAnd(f8_as_int, i8_const(f8_mantissa_mask)); + Value* f16_mantissa = b->CreateZExt(f8_mantissa, i16_type); + f16_mantissa = + b->CreateShl(f16_mantissa, i16_const(mantissa_bits_difference)); + + // Combine the exponent and mantissa: + // f16_as_int = f16_exponent | f16_mantissa + Value* f16_as_int = b->CreateOr(f16_exponent, f16_mantissa); + + // Map from F8 denormal value to F16 value. + int f8_denormal_to_f16[8] = { + 0x0000, // 0 + 0x1800, // 1/8 * 2^-6 + 0x1C00, // 2/8 * 2^-6 + 0x1E00, // 3/8 * 2^-6 + 0x2000, // 4/8 * 2^-6 + 0x2100, // 5/8 * 2^-6 + 0x2200, // 6/8 * 2^-6 + 0x2300, // 7/8 * 2^-6 + }; + + // If the F8 value is denormal, use the map above to determine the correct F16 + // value. + // if (f8_abs_bits < 8) { f16_as_int = f8_denormal_to_f16[f8_abs_bits]; } + for (int i = 0; i < ABSL_ARRAYSIZE(f8_denormal_to_f16); i++) { + Value* is_denormal_value = b->CreateICmpEQ(f8_abs_bits, i8_const(i)); + f16_as_int = b->CreateSelect(is_denormal_value, + i16_const(f8_denormal_to_f16[i]), f16_as_int); + } + + // Set the sign bit. + // f16_as_int |= f16_sign + f16_as_int = b->CreateOr(f16_as_int, f16_sign); + return b->CreateBitCast(f16_as_int, b->getHalfTy()); +} + llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { using llvm::APInt; using llvm::Value; @@ -322,42 +549,8 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { b->CreateICmpUGT(f16_abs_bits, i16_const(max_finite_value)), i8_const(0x7F), f8_bits); - // F16 values that are halfway between denormal F8 values. This is used to - // determine how to round to denormal F8 values. 
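As a cross-check of the widening logic above (exponent re-bias, the 0x7C00 Inf/NaN special case, and the denormal lookup table), here is a rough pure-Python reference decoder. It is only an illustration of the intended semantics, not the emitted IR:

def f8e4m3_to_float(bits: int) -> float:
    sign = -1.0 if bits & 0x80 else 1.0
    exponent = (bits & 0x78) >> 3
    mantissa = bits & 0x07
    if exponent == 0:          # denormal: mantissa/8 * 2**-6, matching f8_denormal_to_f16
        return sign * (mantissa / 8.0) * 2.0**-6
    if exponent == 0xF:        # all-ones exponent: Inf or NaN (the 0x7C00 branch above)
        return sign * float("inf") if mantissa == 0 else float("nan")
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)

assert f8e4m3_to_float(0x38) == 1.0        # 0.0111.000
assert f8e4m3_to_float(0x01) == 2.0**-9    # smallest denormal, 1/8 * 2**-6
assert f8e4m3_to_float(0x77) == 240.0      # largest finite value, 1.875 * 2**7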
- const int halfway_points[8] = { - 0x1400, // 2**-10; halfway between [0, 2**-9] - 0x1A00, // 1.5 * 2**-9; halfway between [2**-9, 2**-8] - 0x1D00, // 1.25 * 2**-8; halfway between [2**-8, 1.5 * 2**-8] - 0x1F00, // 1.75 * 2**-8; halfway between [1.5 * 2**-8, 2**-7] - 0x2080, // 1.125 * 2**-7; halfway between [2**-7, 1.25 * 2**-7] - 0x2180, // 1.375 * 2**-7; halfway between [1.25 * 2**-7, 1.5 * 2**-7] - 0x2280, // 1.625 * 2**-7; halfway between [1.5 * 2**-7, 1.75 * 2**-7] - 0x2380, // 1.875 * 2**-7; halfway between [1.75 * 2**-7, 2**-6] - }; - - // Handle case where output is denormal. If we're rounding to a denormal - // value, ignore the current value of f8_bits and set it to the correct - // denormal value. We emit the equivalent of the following: - // - // if (f16_abs_bits <= halfway_points[0]) { - // f8_bits = 0; - // } else if (f16_abs_bits < halfway_points[1]) { - // f8_bits = 1; - // } else if (f16_abs_bits <= halfway_points[2]) { - // ... // More if-else statements. The comparisons alternate between <= - // ... // and < to handle round-to-even properly. - // } else if (f16_abs_bits < halfway_points[7]) { - // f8_bits = 7; - // } - for (int i = ABSL_ARRAYSIZE(halfway_points) - 1; i >= 0; i--) { - Value* comparison; - if (i % 2 == 0) { - comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); - } else { - comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); - } - f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); - } + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); // Set the sign bit. // f8_bits |= f8_sign @@ -408,7 +601,7 @@ llvm::Value* EmitF8e4m3fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { b->CreateLShr(f8_exponent_bits, i8_const(f8_mantissa_bits)); // Adjust the exponent by adding the difference in exponent bias: - // f16_exponent = (f8_exopnent + exponent_bias_difference) + // f16_exponent = (f8_exponent + exponent_bias_difference) // << f16_mantissa_bits Value* f16_exponent = b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); @@ -604,6 +797,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F8E4M3) { + return EmitF16ToF8e4m3( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } if (to_type == F8E4M3FN) { return EmitF16ToF8e4m3fn( EmitIntegralToFloating(operand_value, from_type, F16, module_, @@ -789,6 +988,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E4M3) { + TF_RET_CHECK(to_type != F8E4M3); + operand_value = EmitF8e4m3ToF16(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E4M3FN) { TF_RET_CHECK(to_type != F8E4M3FN); operand_value = EmitF8e4m3fnToF16(operand_value, b_); @@ -844,6 +1051,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e5m2(operand_value, b_); } + if (to_type == F8E4M3) { + // Cast to F16 first. Casts to F8E4M3 must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF8e4m3(operand_value, b_); + } if (to_type == F8E4M3FN) { // Cast to F16 first. Casts to F8E4M3FN must be from F16. 
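Both the integer and floating-point conversion paths above funnel casts to F8E4M3 through an F16 intermediate. A small NumPy sketch of that two-step lowering, using ml_dtypes purely as a reference for the final rounding; this is an assumption-laden stand-in, not the emitted code path:

import numpy as np
import ml_dtypes

x = np.array([3.14159], dtype=np.float32)
as_f8 = x.astype(np.float16).astype(ml_dtypes.float8_e4m3)   # f32 -> f16 -> f8e4m3
print(as_f8)   # [3.25], the nearest F8E4M3 value (the spacing is 0.25 in [2, 4))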
if (from_type != F16) { @@ -1391,6 +1606,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( if (operand_type == F8E5M2) { lhs_value = EmitF8e5m2ToF16(lhs_value, b_); rhs_value = EmitF8e5m2ToF16(rhs_value, b_); + } else if (operand_type == F8E4M3) { + lhs_value = EmitF8e4m3ToF16(lhs_value, b_); + rhs_value = EmitF8e4m3ToF16(rhs_value, b_); } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index fe3a1041933cb5..a8be22107422a2 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -39,6 +39,8 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F8E4M3: + return &llvm::APFloat::Float8E4M3(); case F8E4M3B11FNUZ: return &llvm::APFloat::Float8E4M3B11FNUZ(); case F8E4M3FN: @@ -67,6 +69,7 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, PrimitiveType type) { switch (type) { + case F8E4M3: case F8E4M3B11FNUZ: case F8E4M3FN: case F8E4M3FNUZ: diff --git a/xla/service/float_normalization_test.cc b/xla/service/float_normalization_test.cc index 8476b38a3b8592..9c393f05bb43e3 100644 --- a/xla/service/float_normalization_test.cc +++ b/xla/service/float_normalization_test.cc @@ -500,6 +500,36 @@ TEST_F(FloatNormalizationTest, DoNotChangeBitcastConvert) { EXPECT_EQ(root->operand(0)->shape().element_type(), U16); } +TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e4m3) { + auto builder = HloComputation::Builder(TestName()); + Shape f16_shape = ShapeUtil::MakeShape(F16, {2, 4}); + Shape f8_shape = ShapeUtil::MakeShape(F8E4M3, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f16_shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, f8_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f16_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(f8_shape, HloOpcode::kMultiply, a, b)); + + HloInstruction* mul1 = builder.AddInstruction( + HloInstruction::CreateBinary(f8_shape, HloOpcode::kMultiply, mul0, c)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get(), F8E4M3, F16)); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), mul1); + EXPECT_EQ(mul0->shape().element_type(), F16); + EXPECT_EQ(mul1->shape().element_type(), F16); + EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); +} + TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) { auto builder = HloComputation::Builder(TestName()); Shape f16_shape = ShapeUtil::MakeShape(F16, {2, 4}); diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 618c7ce38f40b3..e6ce2ec20734ff 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1393,6 +1393,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // Lambdas and related constants: const GpuFloatSupport bf16_support(gpu_version, BF16); const GpuFloatSupport f8e5m2_support(gpu_version, F8E5M2, F16); + const GpuFloatSupport f8e4m3_support(gpu_version, F8E4M3, F16); const GpuFloatSupport f8e4m3fn_support(gpu_version, F8E4M3FN, F16); const 
FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16); @@ -1402,6 +1403,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass("float_normalization"); sub_pipeline.AddPass(&bf16_support); sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3_support); sub_pipeline.AddPass(&f8e4m3fn_support); sub_pipeline.AddPass(&f8e4m3b11fnuz_support); sub_pipeline.AddPass(&f8e5m2fnuz_support); diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index 34b5c703798c23..f450bdf5d0e18f 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -95,6 +95,14 @@ TEST_F(FloatConversionTest, F16ToF8E5M2) { ErrorSpec{1e-5, 1e-5})); } +TEST_F(FloatConversionTest, F16ToF8E4M3) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e4m3[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + TEST_F(FloatConversionTest, F16ToF8E4M3FN) { EXPECT_TRUE(RunAndCompare(R"(ENTRY m { %p = f16[] parameter(0) diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 27630b674d2ce4..06baf691a291a6 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -200,6 +200,7 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt16Ty(module->getContext()); case F8E5M2: case F8E5M2FNUZ: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E4M3FNUZ: diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index ebac59ba7c4eae..939dca71a84fbd 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -37,6 +37,10 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. 
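On backends without native F8E4M3 arithmetic, the FloatSupport entries registered above let FloatNormalization rewrite F8E4M3 math to run in F16 with converts on either side, which is what the float_normalization test earlier exercises. A NumPy stand-in for the rewritten computation, assuming ml_dtypes for the element type:

import numpy as np
import ml_dtypes

a = np.array([1.5, 2.0], dtype=ml_dtypes.float8_e4m3)
b = np.array([3.0, 0.25], dtype=ml_dtypes.float8_e4m3)
# The multiply is performed in F16, then the result is converted back to F8E4M3.
out = (a.astype(np.float16) * b.astype(np.float16)).astype(ml_dtypes.float8_e4m3)
print(out)   # [4.5, 0.5]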
template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E4M3; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kF8E4M3FN; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 951b2f6e147cd8..6260711030d88f 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -66,6 +66,7 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 6f931aeb6324fd..1e49d746f67845 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -46,6 +46,8 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { switch (dtype) { case PrimitiveType::F8E5M2: return DataType::kF8E5M2; + case PrimitiveType::F8E4M3: + return DataType::kF8E4M3; case PrimitiveType::F8E4M3FN: return DataType::kF8E4M3FN; case PrimitiveType::F8E5M2FNUZ: @@ -79,6 +81,8 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { switch (dtype) { case DataType::kF8E5M2: return PrimitiveType::F8E5M2; + case DataType::kF8E4M3: + return PrimitiveType::F8E4M3; case DataType::kF8E4M3FN: return PrimitiveType::F8E4M3FN; case DataType::kF8E5M2FNUZ: @@ -141,6 +145,7 @@ absl::StatusOr GetBlasComputationType( if (algorithm == xla::PrecisionConfig::ALG_UNSET) { switch (output_dtype) { case PrimitiveType::F8E5M2: // fall-through + case PrimitiveType::F8E4M3: // fall-through case PrimitiveType::F8E4M3FN: // fall-through case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index a59c935614cd8f..4d6f89b56cf57a 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -35,8 +35,9 @@ absl::Status ToStatus(hipblasStatus_t status, const char* prefix) { hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: + case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: - LOG(FATAL) << "hipblaslt does not support F8E5M2 and F8E4M3FN"; + LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3 and F8E4M3FN"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index 87088ae974f5f1..9e2ee1347a5e64 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -1405,8 +1405,9 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types(&builder, {2.0f}, {}, error_spec_); } +TEST_F(ConstantsTest, OneCellF8e4m3) { + std::vector constant = {tsl::float8_e4m3{2.0}}; + + XlaBuilder builder(TestName()); + auto c = ConstantR1(&builder, constant); + // F8 outputs are not yet supported so convert to F32 + ConvertElementType(c, F32); + + ComputeAndCompareR1(&builder, {2.0f}, {}, error_spec_); +} + +TEST_F(ConstantsTest, OneCellF8e4m3fn) { + std::vector constant = {tsl::float8_e4m3fn{2.0}}; + + XlaBuilder builder(TestName()); + auto c = ConstantR1(&builder, constant); + // F8 outputs are not yet supported so convert to F32 + ConvertElementType(c, F32); + + ComputeAndCompareR1(&builder, {2.0f}, {}, error_spec_); +} + TEST_F(ConstantsTest, 
OneCellF8e4m3b11fnuz) { std::vector constant = { tsl::float8_e4m3b11fnuz{2.0}}; diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index bddf0c3b9acc4e..621e5e1e5ed848 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -54,9 +54,9 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); TEST_F(ConvertTest, ConvertR1S32ToR1S32) { @@ -878,6 +878,105 @@ XLA_TEST_F(ConvertTest, ConvertF8e5m2BF16RoundtripExhaustive3) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +XLA_TEST_F(ConvertTest, ConvertF16F8e4m3Roundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFCp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, inf}, // Smallest number that overflows + {0x1p8, inf}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest normal + {0x0.2p-6, 0x0.2p-6}, // Smallest denormal + {0x0.Ep-6, 0x0.Ep-6}, // Largest denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3); + ConvertElementType(f8, F16); + const bool saved = + execution_options_.debug_options().xla_allow_excess_precision(); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + false); + ComputeAndCompareR1(&builder, expected_roundtrip, {}); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + saved); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3F16RoundtripExhaustive) { + // Convert from FP8 to FP16, then back to FP8 + XlaBuilder builder(TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); + xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); + ConvertElementType(all_f8_as_f16, F8E4M3); + + // Pass in ErrorSpec, as this causes all NaNs to be treated as equal. + // Round-tripping a NaN will turn it into a quiet NaN and doesn't necessarily + // preserve the payload. + ComputeAndCompareR1(&builder, all_f8, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3F16RoundtripExhaustive2) { + // Convert from F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E4M3); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3BF16RoundtripExhaustive3) { + // Convert from BF16 to FP8. 
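The two round-to-even rows in the roundtrip table above can be checked by hand: in [1, 2) the F8E4M3 spacing is 2**-3, so 0x1.1p0 and 0x1.3p0 fall exactly halfway between neighbours. A small Python sanity check of that arithmetic, for illustration only:

def round_to_e4m3_in_1_2(x: float) -> float:
    step = 2.0**-3                       # ulp of F8E4M3 in [1, 2)
    k = int(x / step)                    # index of the neighbour below
    lo, hi = k * step, (k + 1) * step
    if x - lo != hi - x:
        return lo if x - lo < hi - x else hi
    return lo if k % 2 == 0 else hi      # exact tie: round to the even mantissa

assert round_to_e4m3_in_1_2(0x1.1p0) == 0x1p0     # 1.0625 ties down to 1.0
assert round_to_e4m3_in_1_2(0x1.3p0) == 0x1.4p0   # 1.1875 ties up to 1.25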
+ XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_bf16_to_f8, F8E4M3); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { // Convert from FP16 to FP8, then back to FP16 XlaBuilder builder(TestName()); @@ -1246,8 +1345,8 @@ XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive) { execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( false); - for (auto type : {F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { + for (auto type : {F8E4M3B11FNUZ, F8E4M3, F8E4M3FN, F8E4M3FNUZ, F8E5M2, + F8E5M2FNUZ, F16, BF16, F32, F64}) { xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); xla::XlaOp all_f8_as_type = ConvertElementType(all_f8_as_f8, type); @@ -1285,8 +1384,8 @@ XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive3) { Eigen::numext::bit_cast(static_cast(i))); } - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { + for (auto type : {F8E4M3, F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, + F8E5M2FNUZ, F16, BF16, F32, F64}) { xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); ConvertElementType(all_f8_as_f8, type); @@ -1452,8 +1551,8 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive) { execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( false); - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { + for (auto type : {F8E4M3, F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, + F8E5M2FNUZ, F16, BF16, F32, F64}) { xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); xla::XlaOp all_f8_as_type = ConvertElementType(all_f8_as_f8, type); @@ -1491,8 +1590,8 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive3) { Eigen::numext::bit_cast(static_cast(i))); } - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { + for (auto type : {F8E4M3, F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, + F8E5M2FNUZ, F16, BF16, F32, F64}) { xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); ConvertElementType(all_f8_as_f8, type); @@ -1540,6 +1639,16 @@ XLA_TEST_F(ConvertTest, ConvertF8e5m2ToPred) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ConvertTest, ConvertF8e4m3ToPred) { + XlaBuilder builder(TestName()); + using F8 = tsl::float8_e4m3; + auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); + ConvertElementType(a, PRED); + + std::array expected = {false, true, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + XLA_TEST_F(ConvertTest, ConvertF8e4m3fnToPred) { XlaBuilder builder(TestName()); using F8 = tsl::float8_e4m3fn; diff --git a/xla/tests/float8_test.cc b/xla/tests/float8_test.cc index ab5debea32355b..20bc55455ff9ed 100644 --- a/xla/tests/float8_test.cc +++ b/xla/tests/float8_test.cc @@ -27,11 +27,12 @@ limitations under the License. 
namespace xla { namespace { -// Test FP8 floating-point types (F8E5M2, F8E4M3FN) +// Test FP8 floating-point types (F8E5M2, F8E4M3, F8E4M3FN) template class Float8Test : public ClientLibraryTestBase {}; -using DataTypes = ::testing::Types; +using DataTypes = + ::testing::Types; TYPED_TEST_SUITE(Float8Test, DataTypes); XLA_TYPED_TEST(Float8Test, ScalarOperation) { diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 780968098cf32b..4d5dbb8dc159ca 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -120,6 +120,7 @@ enum PrimitiveType { C64, C128, F8E5M2, + F8E4M3, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, @@ -127,15 +128,18 @@ enum PrimitiveType { }; const std::vector& primitive_strings() { - static auto vec = - new std::vector({"s2", "s4", "s8", - "s16", "s32", "s64", - "u2", "u4", "u8", - "u16", "u32", "u64", - "f16", "bf16", "f32", - "f64", "c64", "c128", - "f8e5m2", "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz"}); + static auto vec = new std::vector({"s2", "s4", + "s8", "s16", + "s32", "s64", + "u2", "u4", + "u8", "u16", + "u32", "u64", + "f16", "bf16", + "f32", "f64", + "c64", "c128", + "f8e5m2", "f8e4m3", + "f8e4m3fn", "f8e4m3b11fnuz", + "f8e5m2fnuz", "f8e4m3fnuz"}); return *vec; } @@ -413,6 +417,7 @@ void Fill(void* buffer, const ArrayShape& shape) { return FillFloatT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: @@ -469,6 +474,7 @@ void Display(const void* buffer, const ArrayShape& shape) { return DisplayT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: diff --git a/xla/translate/hlo_to_mhlo/tests/import.hlo b/xla/translate/hlo_to_mhlo/tests/import.hlo index 0c175bc850e32e..c010fd5e385843 100644 --- a/xla/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/translate/hlo_to_mhlo/tests/import.hlo @@ -415,6 +415,9 @@ add { // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> %constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + %constant.12 = f8e4m3[4] constant({1, 2, 3, 4}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -524,7 +527,13 @@ add { %convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10) // CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32> - ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + + // CHECK-NEXT: %10 = mhlo.convert %9 : (tensor<4xf32>) -> tensor<4xf8E4M3> + %convert.13 = f8e4m3[4] convert(f32[4] %convert.12) + + // CHECK-NEXT: %11 = mhlo.convert %10 : (tensor<4xf8E4M3>) -> tensor<4xf32> + ROOT %convert.14 = f32[4] convert(f8e4m3[4] %convert.13) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/translate/mhlo_to_hlo/tests/export.mlir b/xla/translate/mhlo_to_hlo/tests/export.mlir index 6672e62daf04de..c8c474f9ea435b 100644 --- a/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -600,6 +600,9 @@ func.func @main() { // CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4}) %cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + // CHECK: f8e4m3[4] constant({1, 2, 3, 4}) + %cst_16 = arith.constant dense<[1.000000e+00, 
2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + func.return } @@ -729,7 +732,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32> %6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ> %7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32> - func.return %7 : tensor<2xf32> + %8 = "mhlo.convert"(%7) : (tensor<2xf32>) -> tensor<2xf8E4M3> + %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> + func.return %9 : tensor<2xf32> } // CHECK: ENTRY @@ -741,7 +746,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]]) // CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]]) // CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]]) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) +// CHECK: ROOT %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) // ----- diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index 46fa640ee62298..0f86622236f1eb 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -70,6 +70,7 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index 717ab3e462a7bf..9d44129a51b5f4 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,6 +61,8 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); + numpy_dtypes.float8_e4m3 = + py::dtype::from_args(ml_dtypes.attr("float8_e4m3")).num(); numpy_dtypes.float8_e4m3fn = py::dtype::from_args(ml_dtypes.attr("float8_e4m3fn")).num(); numpy_dtypes.float8_e5m2 = @@ -81,6 +83,7 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. 
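On the Python side, the registration check above amounts to ml_dtypes having installed a NumPy type number for float8_e4m3 at import time. A quick way to observe that, assuming the pinned ml_dtypes version:

import numpy as np
import ml_dtypes

dt = np.dtype(ml_dtypes.float8_e4m3)
print(dt.name, dt.itemsize, dt.num)   # e.g. float8_e4m3 1 <registered type number>
assert dt.num >= 0                    # NPY_NOTYPE here would indicate a failed registration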
if (numpy_dtypes.bfloat16 == NPY_NOTYPE || + numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || numpy_dtypes.float8_e4m3fnuz == NPY_NOTYPE || numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index bf9eab2200a76b..6667cad3891c13 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,6 +24,7 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; + int float8_e4m3; int float8_e4m3fn; int float8_e4m3b11fnuz; int float8_e4m3fnuz; diff --git a/xla/util.cc b/xla/util.cc index 9b1a6db1fa22c0..24a6f7a634e593 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -137,7 +137,7 @@ std::string Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { static_assert(!std::is_same::value, - "RoundTripNanPayload does not support E4M3"); + "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FNUZ"); static_assert(!std::is_same::value, @@ -168,6 +168,12 @@ std::string RoundTripFpToString(tsl::float8_e5m2 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e4m3 value) { + std::string result = GenericRoundTripFpToString(value); + RoundTripNanPayload(value, &result); + return result; +} + std::string RoundTripFpToString(tsl::float8_e4m3fnuz value) { std::string result = GenericRoundTripFpToString(value); return result; diff --git a/xla/util.h b/xla/util.h index 325080f0f08201..c2523f9289931e 100644 --- a/xla/util.h +++ b/xla/util.h @@ -252,8 +252,8 @@ absl::Status AppendStatus(absl::Status prior, absl::string_view context); /*Deduction guide to make variadic arguments play nice with default */ \ /* absl::SourceLocation argument. */ \ template \ - error_type(const absl::FormatSpec& format, \ - Args&&...) -> error_type; + error_type(const absl::FormatSpec& format, Args&&...) \ + -> error_type; #if defined(PLATFORM_GOOGLE) #define XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(error_type) \ @@ -420,6 +420,9 @@ std::string VectorString(const std::initializer_list& c) { std::string RoundTripFpToString(tsl::float8_e5m2 value); // Returns a string which can losslessly round trip to a float8 E4M3. +std::string RoundTripFpToString(tsl::float8_e4m3 value); + +// Returns a string which can losslessly round trip to a float8 E4M3FN. std::string RoundTripFpToString(tsl::float8_e4m3fn value); // Returns a string which can losslessly round trip to a float8 E4M3B11. 
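A rough Python analogue of the lossless round-trip property the new RoundTripFpToString overload documents, sweeping all 256 bit patterns; NaNs are skipped because the C++ overload appends the NaN payload separately, and ml_dtypes is used only as a stand-in for the C++ formatter:

import numpy as np
import ml_dtypes

for bits in range(256):
    v = np.array([bits], dtype=np.uint8).view(ml_dtypes.float8_e4m3)[0]
    f = float(v)                       # exact: every F8E4M3 value fits in a double
    if np.isnan(f):
        continue
    # Formatting and re-parsing the value recovers the original F8E4M3 value.
    assert ml_dtypes.float8_e4m3(float(repr(f))) == v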
diff --git a/xla/util_test.cc b/xla/util_test.cc index 707696ea1c3a99..6cd93cc6b067c7 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -130,6 +130,12 @@ TEST(UtilTest, RoundTripFpToString) { EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( true, QuietNanWithoutPayload())), "-nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + false, QuietNanWithoutPayload())), + "nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + true, QuietNanWithoutPayload())), + "-nan"); EXPECT_EQ( RoundTripFpToString(std::numeric_limits::quiet_NaN()), "nan"); @@ -237,6 +243,18 @@ TEST(UtilTest, TotalOrder_F8E5M2) { } } +TEST(UtilTest, TotalOrder_F8E4M3) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e4m3 x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e4m3 y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + TEST(UtilTest, TotalOrder_F8E4M3FN) { for (int a = 0; a < 256; ++a) { tsl::float8_e4m3fn x = diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 335b59e2064d23..72aa3bb90bac93 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -66,6 +66,9 @@ enum PrimitiveType { // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the // existing IEEE types. // + // F8E4M3 has 4 exponent bits and 3 mantissa bits, and is similar to the + // existing IEEE types. + // // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only // Finite and NaN values are supported. Unlike IEEE types, infinities are not // supported. NaN is represented when the exponent and mantissa bits are all @@ -81,6 +84,7 @@ enum PrimitiveType { // properly in most cases. // TODO(b/259609697): Fully support FP8. F8E5M2 = 19; + F8E4M3 = 28; F8E4M3FN = 20; F8E4M3B11FNUZ = 23; @@ -126,7 +130,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 28 + // Next = 29 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, @@ -573,12 +577,13 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; bytes f8e5m2s = 19; + bytes f8e4m3s = 28; bytes f8e4m3fns = 20; bytes f8e4m3b11fnuzs = 23; bytes f8e5m2fnuzs = 24; bytes f8e4m3fnuzs = 25; repeated int64 sparse_indices = 14; - // Next = 28 + // Next = 29 } message WindowDimension {