Skip to content

Commit

Permalink
Add support for float8_e4m3
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Sep 11, 2024
1 parent acef59a commit 88e586c
Show file tree
Hide file tree
Showing 68 changed files with 992 additions and 184 deletions.
4 changes: 2 additions & 2 deletions third_party/py/ml_dtypes/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ float8 varieties, and int4.
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
ML_DTYPES_COMMIT = "24084d9ed2c3d45bf83b7a9bff833aa185bf9172"
ML_DTYPES_SHA256 = "c916a3e6b3d9bdcb476f506fdbbecb6d5e9f21f82f221dfcb42b320b4e85e55a"
ML_DTYPES_COMMIT = "6f02f77c4fa624d8b467c36d1d959a9b49b07900"
ML_DTYPES_SHA256 = "c5b421a3b8549c020582b9be5e9edf8bb6e9d4284cbd44b0babe6640b4af18da"
tf_http_archive(
name = "ml_dtypes",
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",
Expand Down
4 changes: 2 additions & 2 deletions third_party/tsl/third_party/py/ml_dtypes/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ float8 varieties, and int4.
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
ML_DTYPES_COMMIT = "24084d9ed2c3d45bf83b7a9bff833aa185bf9172"
ML_DTYPES_SHA256 = "c916a3e6b3d9bdcb476f506fdbbecb6d5e9f21f82f221dfcb42b320b4e85e55a"
ML_DTYPES_COMMIT = "6f02f77c4fa624d8b467c36d1d959a9b49b07900"
ML_DTYPES_SHA256 = "c5b421a3b8549c020582b9be5e9edf8bb6e9d4284cbd44b0babe6640b4af18da"
tf_http_archive(
name = "ml_dtypes",
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",
Expand Down
2 changes: 2 additions & 0 deletions third_party/tsl/tools/def_file_filter/symbols_pybind.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ tsl::ml_dtypes::RegisterTypes
tsl::ml_dtypes::GetBfloat16Dtype
tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype
tsl::ml_dtypes::GetFloat8E4m3fnDtype
tsl::ml_dtypes::GetFloat8E4m3Dtype
tsl::ml_dtypes::GetFloat8E5m2Dtype
tsl::ml_dtypes::GetBfloat16TypeNum
tsl::ml_dtypes::GetFloat8E4m3b11fnuzTypeNum
tsl::ml_dtypes::GetFloat8E4m3fnTypeNum
tsl::ml_dtypes::GetFloat8E4m3TypeNum
tsl::ml_dtypes::GetFloat8E5m2TypeNum

[//tensorflow/python:py_func_lib] # py_func
Expand Down
3 changes: 2 additions & 1 deletion third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ limitations under the License.
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes

namespace tsl {
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
Expand Down
1 change: 1 addition & 0 deletions third_party/tsl/tsl/protobuf/dnn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum DataType {
kF8E5M2FNUZ = 10;
kF8E4M3FNUZ = 11;
kInt64 = 12;
kF8E4M3 = 13;
}

// Describes how a convolution input or output layer's data is formatted.
Expand Down
14 changes: 14 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF8E4M3) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF8E4M3Fn) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3fn>(1.0, 3.5, 3, 2);

Expand Down
46 changes: 25 additions & 21 deletions xla/client/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ XlaOp IsNegZero(XlaOp operand) {
return Eq(BitcastConvertType(operand, U32),
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
case F8E5M2:
case F8E4M3:
case F8E4M3FN:
case F8E4M3B11FNUZ:
case F8E5M2FNUZ:
Expand Down Expand Up @@ -315,12 +316,14 @@ XlaOp Erfc(XlaOp x) {
}
// Erf(c)Impl don't have enough precision when run with bf16 intermediates
// (not surprising!), so upcast to f32 in this case.
return DoWithUpcastToF32(
x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
[](XlaOp x) {
return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
ScalarLike(x, 1) - ErfImpl32Cephes(x));
});
return DoWithUpcastToF32(x,
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
[](XlaOp x) {
return Select(
Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
ScalarLike(x, 1) - ErfImpl32Cephes(x));
});
});
}

Expand Down Expand Up @@ -492,9 +495,10 @@ XlaOp ErfInv(XlaOp x) {
if (shape.element_type() == F64) {
return ErfInv64(x);
}
return DoWithUpcastToF32(
x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
[](XlaOp x) { return ErfInv32(x); });
return DoWithUpcastToF32(x,
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
[](XlaOp x) { return ErfInv32(x); });
});
}

Expand Down Expand Up @@ -622,10 +626,10 @@ XlaOp Lgamma(XlaOp input) {
// F16 and BF16 don't provide sufficient precision for intermediate results
// here (although it's better than you might expect!), so do the
// computations in F32.
return DoWithUpcastToF32(
input,
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
do_it);
return DoWithUpcastToF32(input,
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
do_it);
});
}

Expand Down Expand Up @@ -720,10 +724,10 @@ XlaOp Digamma(XlaOp input) {
auto& b = *input.builder();
return b.ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
return DoWithUpcastToF32(
input,
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
do_it);
return DoWithUpcastToF32(input,
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
do_it);
});
}

Expand Down Expand Up @@ -978,8 +982,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down Expand Up @@ -1031,8 +1035,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "TOKEN";
case XLA_FFI_DataType_F8E5M2:
return os << "F8E5M2";
case XLA_FFI_DataType_F8E4M3:
return os << "F8E4M3";
case XLA_FFI_DataType_F8E4M3FN:
return os << "F8E4M3FN";
case XLA_FFI_DataType_F8E4M3B11FNUZ:
Expand Down
1 change: 1 addition & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ typedef enum {
XLA_FFI_DataType_C128 = 18,
XLA_FFI_DataType_TOKEN = 17,
XLA_FFI_DataType_F8E5M2 = 19,
XLA_FFI_DataType_F8E4M3 = 28,
XLA_FFI_DataType_F8E4M3FN = 20,
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
XLA_FFI_DataType_F8E5M2FNUZ = 24,
Expand Down
3 changes: 3 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ enum class DataType : uint8_t {
C128 = XLA_FFI_DataType_C128,
TOKEN = XLA_FFI_DataType_TOKEN,
F8E5M2 = XLA_FFI_DataType_F8E5M2,
F8E4M3 = XLA_FFI_DataType_F8E4M3,
F8E4M3FN = XLA_FFI_DataType_F8E4M3FN,
F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ,
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
Expand All @@ -96,6 +97,7 @@ inline constexpr DataType C64 = DataType::C64;
inline constexpr DataType C128 = DataType::C128;
inline constexpr DataType TOKEN = DataType::TOKEN;
inline constexpr DataType F8E5M2 = DataType::F8E5M2;
inline constexpr DataType F8E4M3 = DataType::F8E4M3;
inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
Expand All @@ -115,6 +117,7 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::S8:
case DataType::U8:
case DataType::F8E5M2:
case DataType::F8E4M3:
case DataType::F8E4M3FN:
case DataType::F8E4M3B11FNUZ:
case DataType::F8E5M2FNUZ:
Expand Down
3 changes: 3 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ TEST(FfiTest, DataTypeEnumValue) {
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));

EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ),
encoded(DataType::F8E4M3B11FNUZ));
Expand Down Expand Up @@ -176,6 +177,8 @@ TEST(FfiTest, DataTypeByteWidth) {

EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
ByteWidth(DataType::F8E5M2));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
ByteWidth(DataType::F8E4M3));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN),
ByteWidth(DataType::F8E4M3FN));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ),
Expand Down
1 change: 1 addition & 0 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::C128:
case PrimitiveType::TOKEN:
case PrimitiveType::F8E5M2:
case PrimitiveType::F8E4M3:
case PrimitiveType::F8E4M3FN:
case PrimitiveType::F8E4M3B11FNUZ:
case PrimitiveType::F8E5M2FNUZ:
Expand Down
53 changes: 53 additions & 0 deletions xla/fp_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,59 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
0x1.fffffffffffffp-127,
0x1.aaaaaaaaaaaaap-127));

TEST(FPDistanceTest, F8E4M3Distance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3>(tsl::float8_e4m3(8.0),
tsl::float8_e4m3(8.0)),
0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3>(tsl::float8_e4m3(8.0),
tsl::float8_e4m3(13)),
5);

// a & b have different exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3>(tsl::float8_e4m3(8.0),
tsl::float8_e4m3(6.0)),
4);

// 1 from 0 in the positive direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3>(
std::numeric_limits<tsl::float8_e4m3>::denorm_min(),
tsl::float8_e4m3(0)),
1);

// 1 from 0 in the negative direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3>(
-std::numeric_limits<tsl::float8_e4m3>::denorm_min(),
tsl::float8_e4m3(0)),
1);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3>(
-std::numeric_limits<tsl::float8_e4m3>::denorm_min(),
std::numeric_limits<tsl::float8_e4m3>::denorm_min()),
2);

// 1 non denorm from 0 in the positive direction
EXPECT_EQ(
CalculateDistanceInFloats<tsl::float8_e4m3>(
std::numeric_limits<tsl::float8_e4m3>::min(), tsl::float8_e4m3(0)),
8);

// 1 non denorm from 0 in the negative direction
EXPECT_EQ(
CalculateDistanceInFloats<tsl::float8_e4m3>(
-std::numeric_limits<tsl::float8_e4m3>::min(), tsl::float8_e4m3(0)),
8);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3>(
-std::numeric_limits<tsl::float8_e4m3>::min(),
std::numeric_limits<tsl::float8_e4m3>::min()),
16);
}

TEST(FPDistanceTest, F8E4M3FNDistance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e4m3fn>(
Expand Down
87 changes: 44 additions & 43 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,51 +520,51 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
absl::Status HandlePower(const HloInstruction* power) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[power],
ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
ElementwiseT rhs_el) {
// Case 0: 1^x = 1 and x^0 = 1, regardless of X, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
if (lhs_el == ElementwiseT(1) || rhs_el == ElementwiseT(0)) {
return static_cast<ElementwiseT>(1);
}
// Case 1:
// 1. inf^(a + 0i) = inf, if a > 0.
// 2. inf^(a + 0i) = 0, if a < 0.
if constexpr (is_complex_v<ElementwiseT>) {
auto is_positive_infinity = [](ElementwiseT c) {
return c.imag() == 0 && c.real() > 0 && std::isinf(c.real());
};
auto is_positive_real = [](ElementwiseT c) {
return c.real() > 0 && c.imag() == 0;
};
auto is_negative_real = [](ElementwiseT c) {
return c.real() < 0 && c.imag() == 0;
};
if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el)) {
return static_cast<ElementwiseT>(lhs_el);
}
if (is_positive_infinity(lhs_el) && is_negative_real(rhs_el)) {
return static_cast<ElementwiseT>(0);
}
}
// Case 2:
// Fallback to pow.
if constexpr (std::is_same_v<ElementwiseT, bool>) {
return lhs_el || !rhs_el;
} else if constexpr (std::is_integral_v<ElementwiseT>) {
if constexpr (std::is_signed_v<ElementwiseT>) {
if (rhs_el < static_cast<ElementwiseT>(0)) {
ElementWiseBinaryOp(
power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
// Case 0: 1^x = 1 and x^0 = 1, regardless of X, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
if (lhs_el == ElementwiseT(1) || rhs_el == ElementwiseT(0)) {
return static_cast<ElementwiseT>(1);
}
// Case 1:
// 1. inf^(a + 0i) = inf, if a > 0.
// 2. inf^(a + 0i) = 0, if a < 0.
if constexpr (is_complex_v<ElementwiseT>) {
auto is_positive_infinity = [](ElementwiseT c) {
return c.imag() == 0 && c.real() > 0 && std::isinf(c.real());
};
auto is_positive_real = [](ElementwiseT c) {
return c.real() > 0 && c.imag() == 0;
};
auto is_negative_real = [](ElementwiseT c) {
return c.real() < 0 && c.imag() == 0;
};
if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el)) {
return static_cast<ElementwiseT>(lhs_el);
}
if (is_positive_infinity(lhs_el) && is_negative_real(rhs_el)) {
return static_cast<ElementwiseT>(0);
}
}
// Case 2:
// Fallback to pow.
if constexpr (std::is_same_v<ElementwiseT, bool>) {
return lhs_el || !rhs_el;
} else if constexpr (std::is_integral_v<ElementwiseT>) {
if constexpr (std::is_signed_v<ElementwiseT>) {
if (rhs_el < static_cast<ElementwiseT>(0)) {
return static_cast<ElementwiseT>(
lhs_el == static_cast<ElementwiseT>(1) ? 1 : 0);
}
}
return static_cast<ElementwiseT>(
lhs_el == static_cast<ElementwiseT>(1) ? 1 : 0);
IPow<std::make_unsigned_t<ElementwiseT>>(lhs_el, rhs_el));
} else {
return static_cast<ElementwiseT>(std::pow(lhs_el, rhs_el));
}
}
return static_cast<ElementwiseT>(
IPow<std::make_unsigned_t<ElementwiseT>>(lhs_el, rhs_el));
} else {
return static_cast<ElementwiseT>(std::pow(lhs_el, rhs_el));
}
}));
}));
return absl::OkStatus();
}

Expand Down Expand Up @@ -1743,6 +1743,7 @@ extern template class HloEvaluatorTypedVisitor<complex64>;
extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
Expand Down
Loading

0 comments on commit 88e586c

Please sign in to comment.