Skip to content

Commit

Permalink
Add support for float8_e4m3 and float8_e3m4 types
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Sep 23, 2024
1 parent fa64068 commit 0e9a38c
Show file tree
Hide file tree
Showing 67 changed files with 1,518 additions and 271 deletions.
4 changes: 4 additions & 0 deletions third_party/tsl/tools/def_file_filter/symbols_pybind.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
[//external/local_xla/xla/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8
tsl::ml_dtypes::RegisterTypes
tsl::ml_dtypes::GetBfloat16Dtype
tsl::ml_dtypes::GetFloat8E3m4Dtype
tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype
tsl::ml_dtypes::GetFloat8E4m3fnDtype
tsl::ml_dtypes::GetFloat8E4m3Dtype
tsl::ml_dtypes::GetFloat8E5m2Dtype
tsl::ml_dtypes::GetBfloat16TypeNum
tsl::ml_dtypes::GetFloat8E3m4TypeNum
tsl::ml_dtypes::GetFloat8E4m3b11fnuzTypeNum
tsl::ml_dtypes::GetFloat8E4m3fnTypeNum
tsl::ml_dtypes::GetFloat8E4m3TypeNum
tsl::ml_dtypes::GetFloat8E5m2TypeNum

[//tensorflow/python:py_func_lib] # py_func
Expand Down
4 changes: 3 additions & 1 deletion third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ limitations under the License.
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes

namespace tsl {
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
Expand Down
2 changes: 2 additions & 0 deletions third_party/tsl/tsl/protobuf/dnn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ enum DataType {
kF8E5M2FNUZ = 10;
kF8E4M3FNUZ = 11;
kInt64 = 12;
kF8E4M3 = 13;
kF8E3M4 = 14;
}

// Describes how a convolution input or output layer's data is formatted.
Expand Down
28 changes: 28 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF8E4M3) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF8E4M3Fn) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3fn>(1.0, 3.5, 3, 2);

Expand Down Expand Up @@ -190,6 +204,20 @@ TEST(Array2dTest, LinspaceF8E4M3FnNoNan) {
}
}

TEST(Array2dTest, LinspaceF8E3M4) {
auto arr = MakeLinspaceArray2D<tsl::float8_e3m4>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
10 changes: 6 additions & 4 deletions xla/client/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ XlaOp IsNegZero(XlaOp operand) {
case F32:
return Eq(BitcastConvertType(operand, U32),
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
case F8E3M4:
case F8E4M3:
case F8E5M2:
case F8E4M3FN:
case F8E4M3B11FNUZ:
Expand Down Expand Up @@ -973,8 +975,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down Expand Up @@ -1026,8 +1028,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down
4 changes: 4 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "TOKEN";
case XLA_FFI_DataType_F8E5M2:
return os << "F8E5M2";
case XLA_FFI_DataType_F8E3M4:
return os << "F8E3M4";
case XLA_FFI_DataType_F8E4M3:
return os << "F8E4M3";
case XLA_FFI_DataType_F8E4M3FN:
return os << "F8E4M3FN";
case XLA_FFI_DataType_F8E4M3B11FNUZ:
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ typedef enum {
XLA_FFI_DataType_C128 = 18,
XLA_FFI_DataType_TOKEN = 17,
XLA_FFI_DataType_F8E5M2 = 19,
XLA_FFI_DataType_F8E3M4 = 29,
XLA_FFI_DataType_F8E4M3 = 28,
XLA_FFI_DataType_F8E4M3FN = 20,
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
XLA_FFI_DataType_F8E5M2FNUZ = 24,
Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ enum class DataType : uint8_t {
C128 = XLA_FFI_DataType_C128,
TOKEN = XLA_FFI_DataType_TOKEN,
F8E5M2 = XLA_FFI_DataType_F8E5M2,
F8E4M3 = XLA_FFI_DataType_F8E4M3,
F8E4M3FN = XLA_FFI_DataType_F8E4M3FN,
F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ,
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
F8E3M4 = XLA_FFI_DataType_F8E3M4,
};

// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
Expand All @@ -98,10 +100,12 @@ inline constexpr DataType C64 = DataType::C64;
inline constexpr DataType C128 = DataType::C128;
inline constexpr DataType TOKEN = DataType::TOKEN;
inline constexpr DataType F8E5M2 = DataType::F8E5M2;
inline constexpr DataType F8E4M3 = DataType::F8E4M3;
inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
inline constexpr DataType F8E3M4 = DataType::F8E3M4;

inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
Expand All @@ -117,10 +121,12 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::S8:
case DataType::U8:
case DataType::F8E5M2:
case DataType::F8E4M3:
case DataType::F8E4M3FN:
case DataType::F8E4M3B11FNUZ:
case DataType::F8E5M2FNUZ:
case DataType::F8E4M3FNUZ:
case DataType::F8E3M4:
return 1;
case DataType::S16:
case DataType::U16:
Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ TEST(FfiTest, DataTypeEnumValue) {
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));

EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ),
encoded(DataType::F8E4M3B11FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
}

TEST(FfiTest, DataTypeByteWidth) {
Expand Down Expand Up @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {

EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
ByteWidth(DataType::F8E5M2));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
ByteWidth(DataType::F8E4M3));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN),
ByteWidth(DataType::F8E4M3FN));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ),
Expand All @@ -187,6 +191,8 @@ TEST(FfiTest, DataTypeByteWidth) {
ByteWidth(DataType::F8E5M2FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FNUZ),
ByteWidth(DataType::F8E4M3FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
ByteWidth(DataType::F8E3M4));
}

TEST(FfiTest, ErrorEnumValue) {
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,12 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::C128:
case PrimitiveType::TOKEN:
case PrimitiveType::F8E5M2:
case PrimitiveType::F8E4M3:
case PrimitiveType::F8E4M3FN:
case PrimitiveType::F8E4M3B11FNUZ:
case PrimitiveType::F8E5M2FNUZ:
case PrimitiveType::F8E4M3FNUZ:
case PrimitiveType::F8E3M4:
return static_cast<XLA_FFI_DataType>(primitive_type);
default:
DCHECK(false) << "Unsupported primitive type "
Expand Down
61 changes: 57 additions & 4 deletions xla/fp_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,74 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
0x1.fffffffffffffp-127,
0x1.aaaaaaaaaaaaap-127));

// Test F8E4M3 floating-point types (F8E4M3FN)
// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN)
template <typename T>
class FP8E4M3DistanceTest : public ::testing::Test {};

using F8E4M3Types = ::testing::Types<tsl::float8_e4m3fn>;
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);

TEST(FPDistanceTest, F8E3M4Distance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
tsl::float8_e3m4(8.0)),
0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
tsl::float8_e3m4(15.5)),
15);

// a & b have different exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
tsl::float8_e3m4(6)),
8);

// 1 from 0 in the positive direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
tsl::float8_e3m4(0)),
1);

// 1 from 0 in the negative direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
tsl::float8_e3m4(0)),
1);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
std::numeric_limits<tsl::float8_e3m4>::denorm_min()),
2);

// 1 non denorm from 0 in the positive direction
EXPECT_EQ(
CalculateDistanceInFloats<tsl::float8_e3m4>(
std::numeric_limits<tsl::float8_e3m4>::min(), tsl::float8_e3m4(0)),
16);

// 1 non denorm from 0 in the negative direction
EXPECT_EQ(
CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::min(), tsl::float8_e3m4(0)),
16);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
-std::numeric_limits<tsl::float8_e3m4>::min(),
std::numeric_limits<tsl::float8_e3m4>::min()),
32);
}

TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) {
// a & b are equal, distance should be 0
EXPECT_EQ(
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(8.0)), 0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(13)),
5);
EXPECT_EQ(
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(15.0)), 7);

// a & b have different exponents
EXPECT_EQ(
Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1743,10 +1743,12 @@ extern template class HloEvaluatorTypedVisitor<complex64>;
extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;

} // namespace xla

Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ limitations under the License.

namespace xla {
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;
} // namespace xla
24 changes: 21 additions & 3 deletions xla/hlo/translate/hlo_to_mhlo/tests/import.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,17 @@ add {
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3B11FNUZ>
%constant.9 = f8e4m3b11fnuz[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ>
// CHECK: %[[VAL_10:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ>
%constant.10 = f8e4m3fnuz[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>
// CHECK: %[[VAL_11:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>
%constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3>
%constant.12 = f8e4m3[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>
%constant.13 = f8e3m4[4] constant({1, 2, 3, 4})
}

// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
Expand Down Expand Up @@ -524,7 +530,19 @@ add {
%convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10)

// CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32>
ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11)
%convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11)

// CHECK-NEXT: %10 = mhlo.convert %9 : (tensor<4xf32>) -> tensor<4xf8E4M3>
%convert.13 = f8e4m3[4] convert(f32[4] %convert.12)

// CHECK-NEXT: %11 = mhlo.convert %10 : (tensor<4xf8E4M3>) -> tensor<4xf32>
%convert.14 = f32[4] convert(f8e4m3[4] %convert.13)

// CHECK-NEXT: %12 = mhlo.convert %11 : (tensor<4xf32>) -> tensor<4xf8E3M4>
%convert.15 = f8e3m4[4] convert(f32[4] %convert.14)

// CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32>
ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15)
}

// CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8>
Expand Down
18 changes: 16 additions & 2 deletions xla/hlo/translate/mhlo_to_hlo/tests/export.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,12 @@ func.func @main() {
// CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4})
%cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>

// CHECK: f8e4m3[4] constant({1, 2, 3, 4})
%cst_16 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3>

// CHECK: f8e3m4[4] constant({1, 2, 3, 4})
%cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>

func.return
}

Expand Down Expand Up @@ -729,7 +735,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32>
%6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ>
%7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32>
func.return %7 : tensor<2xf32>
%8 = "mhlo.convert"(%7) : (tensor<2xf32>) -> tensor<2xf8E4M3>
%9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32>
%10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4>
%11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32>
func.return %11 : tensor<2xf32>
}

// CHECK: ENTRY
Expand All @@ -741,7 +751,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]])
// CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]])
// CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]])
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]])
// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]])
// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]])
// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]])
// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]])
// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])

// -----

Expand Down
Loading

0 comments on commit 0e9a38c

Please sign in to comment.