Skip to content

Commit

Permalink
Add support for float8_e4m3
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Sep 18, 2024
1 parent 0651342 commit 7bda957
Show file tree
Hide file tree
Showing 67 changed files with 921 additions and 189 deletions.
2 changes: 2 additions & 0 deletions third_party/tsl/tools/def_file_filter/symbols_pybind.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ tsl::ml_dtypes::RegisterTypes
tsl::ml_dtypes::GetBfloat16Dtype
tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype
tsl::ml_dtypes::GetFloat8E4m3fnDtype
tsl::ml_dtypes::GetFloat8E4m3Dtype
tsl::ml_dtypes::GetFloat8E5m2Dtype
tsl::ml_dtypes::GetBfloat16TypeNum
tsl::ml_dtypes::GetFloat8E4m3b11fnuzTypeNum
tsl::ml_dtypes::GetFloat8E4m3fnTypeNum
tsl::ml_dtypes::GetFloat8E4m3TypeNum
tsl::ml_dtypes::GetFloat8E5m2TypeNum

[//tensorflow/python:py_func_lib] # py_func
Expand Down
3 changes: 2 additions & 1 deletion third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ limitations under the License.
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes

namespace tsl {
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
Expand Down
1 change: 1 addition & 0 deletions third_party/tsl/tsl/protobuf/dnn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum DataType {
kF8E5M2FNUZ = 10;
kF8E4M3FNUZ = 11;
kInt64 = 12;
kF8E4M3 = 13;
}

// Describes how a convolution input or output layer's data is formatted.
Expand Down
14 changes: 14 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

// Builds a 3x2 linspace array of tsl::float8_e4m3 spanning [1.0, 3.5] and
// checks both its dimensions and every element. The expected values advance
// by 0.5 per element, with the second index varying fastest.
TEST(Array2dTest, LinspaceF8E4M3) {
  auto grid = MakeLinspaceArray2D<tsl::float8_e4m3>(1.0, 3.5, 3, 2);

  // Reads one element and widens it to float for comparison.
  auto at = [&grid](int row, int col) {
    return static_cast<float>((*grid)(row, col));
  };

  EXPECT_EQ(grid->n1(), 3);
  EXPECT_EQ(grid->n2(), 2);

  // First row.
  EXPECT_FLOAT_EQ(at(0, 0), 1.0);
  EXPECT_FLOAT_EQ(at(0, 1), 1.5);
  // Second row.
  EXPECT_FLOAT_EQ(at(1, 0), 2.0);
  EXPECT_FLOAT_EQ(at(1, 1), 2.5);
  // Third row.
  EXPECT_FLOAT_EQ(at(2, 0), 3.0);
  EXPECT_FLOAT_EQ(at(2, 1), 3.5);
}

TEST(Array2dTest, LinspaceF8E4M3Fn) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3fn>(1.0, 3.5, 3, 2);

Expand Down
9 changes: 5 additions & 4 deletions xla/client/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ XlaOp IsNegZero(XlaOp operand) {
case F32:
return Eq(BitcastConvertType(operand, U32),
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
case F8E4M3:
case F8E5M2:
case F8E4M3FN:
case F8E4M3B11FNUZ:
Expand Down Expand Up @@ -973,8 +974,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down Expand Up @@ -1026,8 +1027,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "TOKEN";
case XLA_FFI_DataType_F8E5M2:
return os << "F8E5M2";
case XLA_FFI_DataType_F8E4M3:
return os << "F8E4M3";
case XLA_FFI_DataType_F8E4M3FN:
return os << "F8E4M3FN";
case XLA_FFI_DataType_F8E4M3B11FNUZ:
Expand Down
1 change: 1 addition & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ typedef enum {
XLA_FFI_DataType_C128 = 18,
XLA_FFI_DataType_TOKEN = 17,
XLA_FFI_DataType_F8E5M2 = 19,
XLA_FFI_DataType_F8E4M3 = 28,
XLA_FFI_DataType_F8E4M3FN = 20,
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
XLA_FFI_DataType_F8E5M2FNUZ = 24,
Expand Down
3 changes: 3 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ enum class DataType : uint8_t {
C128 = XLA_FFI_DataType_C128,
TOKEN = XLA_FFI_DataType_TOKEN,
F8E5M2 = XLA_FFI_DataType_F8E5M2,
F8E4M3 = XLA_FFI_DataType_F8E4M3,
F8E4M3FN = XLA_FFI_DataType_F8E4M3FN,
F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ,
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
Expand All @@ -98,6 +99,7 @@ inline constexpr DataType C64 = DataType::C64;
inline constexpr DataType C128 = DataType::C128;
inline constexpr DataType TOKEN = DataType::TOKEN;
inline constexpr DataType F8E5M2 = DataType::F8E5M2;
inline constexpr DataType F8E4M3 = DataType::F8E4M3;
inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
Expand All @@ -117,6 +119,7 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::S8:
case DataType::U8:
case DataType::F8E5M2:
case DataType::F8E4M3:
case DataType::F8E4M3FN:
case DataType::F8E4M3B11FNUZ:
case DataType::F8E5M2FNUZ:
Expand Down
3 changes: 3 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ TEST(FfiTest, DataTypeEnumValue) {
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));

EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ),
encoded(DataType::F8E4M3B11FNUZ));
Expand Down Expand Up @@ -179,6 +180,8 @@ TEST(FfiTest, DataTypeByteWidth) {

EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
ByteWidth(DataType::F8E5M2));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
ByteWidth(DataType::F8E4M3));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN),
ByteWidth(DataType::F8E4M3FN));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ),
Expand Down
1 change: 1 addition & 0 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::C128:
case PrimitiveType::TOKEN:
case PrimitiveType::F8E5M2:
case PrimitiveType::F8E4M3:
case PrimitiveType::F8E4M3FN:
case PrimitiveType::F8E4M3B11FNUZ:
case PrimitiveType::F8E5M2FNUZ:
Expand Down
4 changes: 2 additions & 2 deletions xla/fp_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
0x1.fffffffffffffp-127,
0x1.aaaaaaaaaaaaap-127));

// Test F8E4M3 floating-point types (F8E4M3FN)
// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN)
// Empty typed-test fixture: it adds no state or helpers of its own. It is
// registered with TYPED_TEST_SUITE against the F8E4M3Types list below, so each
// TYPED_TEST body runs once per listed type.
template <typename T>
class FP8E4M3DistanceTest : public ::testing::Test {};

using F8E4M3Types = ::testing::Types<tsl::float8_e4m3fn>;
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);

TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) {
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1743,6 +1743,7 @@ extern template class HloEvaluatorTypedVisitor<complex64>;
extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

namespace xla {
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
Expand Down
24 changes: 18 additions & 6 deletions xla/literal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ bool LiteralProtoHasValues(const LiteralProto& proto) {
!proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() ||
!proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() ||
!proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() ||
!proto.f8e5m2s().empty() || !proto.f8e4m3fns().empty() ||
!proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() ||
!proto.f8e4m3fnuzs().empty() || !proto.f16s().empty() ||
!proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() ||
proto.c64s_size() || proto.c128s_size() || proto.preds_size() ||
proto.tuple_literals_size();
!proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() ||
!proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() ||
!proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() ||
!proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() ||
proto.f64s_size() || proto.c64s_size() || proto.c128s_size() ||
proto.preds_size() || proto.tuple_literals_size();
}

// Lazy getter for the interned scalar shape in static storage. We reuse this
Expand Down Expand Up @@ -2258,6 +2258,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
reinterpret_cast<const char*>(data<tsl::float8_e5m2>().data()),
size_bytes_dense());
break;
case F8E4M3:
*proto->mutable_f8e4m3s() = std::string(
reinterpret_cast<const char*>(data<tsl::float8_e4m3>().data()),
size_bytes_dense());
break;
case F8E4M3FN:
*proto->mutable_f8e4m3fns() = std::string(
reinterpret_cast<const char*>(data<tsl::float8_e4m3fn>().data()),
Expand Down Expand Up @@ -2436,6 +2441,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
memcpy(untyped_data(), s.data(), s.size());
break;
}
case F8E4M3: {
const std::string& s(proto.f8e4m3s());
TF_RET_CHECK(data<tsl::float8_e4m3>().size() * sizeof(tsl::float8_e4m3) ==
s.size());
memcpy(untyped_data(), s.data(), s.size());
break;
}
case F8E4M3FN: {
const std::string& s(proto.f8e4m3fns());
TF_RET_CHECK(data<tsl::float8_e4m3fn>().size() *
Expand Down
2 changes: 1 addition & 1 deletion xla/literal_comparison_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace {
// Empty typed-test fixture: it adds no state or helpers of its own. It is
// registered with TYPED_TEST_SUITE against the TestedTypes list below, so each
// typed test runs once per listed float8 type.
template <typename T>
class LiteralComparisonTest : public ::testing::Test {};

using TestedTypes = ::testing::Types<tsl::float8_e4m3fn,
using TestedTypes = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn,
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2>;
TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes);

Expand Down
Loading

0 comments on commit 7bda957

Please sign in to comment.