diff --git a/xla/literal_util.cc b/xla/literal_util.cc index 2330aca215483b..503a746bb9ac6f 100644 --- a/xla/literal_util.cc +++ b/xla/literal_util.cc @@ -254,6 +254,16 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, return ConvertType(f32_literal); } +/* static */ Literal LiteralUtil::ConvertF32ToF8E5M2( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + +/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FN( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + /* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); diff --git a/xla/literal_util.h b/xla/literal_util.h index 1048682e2d5f4e..f43d1dbeddffa7 100644 --- a/xla/literal_util.h +++ b/xla/literal_util.h @@ -244,6 +244,8 @@ class LiteralUtil { static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal); static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal); static Literal ConvertF32ToF8E5M2FNUZ(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E5M2(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E4M3FN(const LiteralSlice& f32_literal); static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); static Literal ConvertF32ToS8(const LiteralSlice& f32_literal); static Literal ConvertF32ToF64(const LiteralSlice& f32_literal); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index be20715cbe4aed..a71483c689c6f5 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -272,6 +272,7 @@ cc_library( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", "//xla/stream_executor", diff --git a/xla/tests/client_library_test_base.cc b/xla/tests/client_library_test_base.cc index 71b6f9bc175a80..e0599f0f35e151 100644 --- a/xla/tests/client_library_test_base.cc +++ b/xla/tests/client_library_test_base.cc @@ -26,11 +26,13 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/test_helpers.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" namespace xla { @@ -291,7 +293,7 @@ absl::StatusOr ClientLibraryTestBase::ComputeAndTransfer( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } @@ -315,7 +317,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } @@ -326,20 +328,20 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; } - // We allow using a float expected literal for a bfloat16 output. In this - // case, we need to convert the expected literal to bfloat16. + // We allow using a float expected literal for non float outputs. In this + // case, we need to convert the expected literal to test_type_. const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16()) { - converted_expected = LiteralUtil::ConvertF32ToBF16(expected); + if (test_type_ != F32) { + converted_expected = MaybeConvertLiteralToTestType(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); + subshape->set_element_type(test_type_); } }); shape_with_layout = &layout_shape; @@ -377,27 +379,27 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - // We allow using a float expected literal for a bfloat16 output. In this - // case, we need to convert the expected literal to bfloat16. + // We allow using a float expected literal for a non float outputs. In this + // case, we need to convert the expected literal to type_test_. const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16()) { - converted_expected = LiteralUtil::ConvertF32ToBF16(expected); + if (test_type_ != F32) { + converted_expected = MaybeConvertLiteralToTestType(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); + subshape->set_element_type(test_type_); } }); shape_with_layout = &layout_shape; @@ -535,13 +537,11 @@ ClientLibraryTestBase::ComputeValueAndReference( return std::make_pair(std::move(reference), std::move(result)); } -XlaComputation ClientLibraryTestBase::CreateScalarRelu() { +XlaComputation ClientLibraryTestBase::CreateScalarReluF32() { XlaBuilder builder("relu"); - auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(F32, {}); auto z_value = Parameter(&builder, 0, shape, "z_value"); - auto zero = use_bfloat16() - ? ConstantR0(&builder, static_cast(0.0f)) - : ConstantR0(&builder, 0.0f); + auto zero = ConstantR0(&builder, 0.0f); Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -550,7 +550,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaBuilder builder("max"); - auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(test_type_, {}); auto x = Parameter(&builder, 0, shape, "x"); auto y = Parameter(&builder, 1, shape, "y"); Max(x, y); @@ -559,22 +559,6 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { return std::move(computation_status).value(); } -XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { - XlaBuilder builder("relu_sensitivity"); - auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {}); - auto activation = Parameter(&builder, 0, shape, "activation"); - auto backprop = Parameter(&builder, 1, shape, "backprop"); - auto zero = use_bfloat16() - ? ConstantR0(&builder, static_cast(0.0f)) - : ConstantR0(&builder, 0.0f); - auto activation_gtz = Gt(activation, zero); - Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); - - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return std::move(computation_status).value(); -} - std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { auto array = std::make_unique>(rows, cols); @@ -605,7 +589,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { arguments_.push_back(argument.Clone()); return Parameter(builder, /*parameter_number=*/arguments_.size() - 1, - MaybeConvertShapeToBfloat16(argument.shape()), ""); + MaybeConvertShapeToTestType(argument.shape()), ""); } XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, @@ -623,26 +607,34 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( nullptr, builder, data_handle); } -Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { - if (!use_bfloat16()) { +Shape ClientLibraryTestBase::MaybeConvertShapeToTestType(const Shape& shape) { + if (test_type_ == F32) { return shape; } Shape new_shape = shape; - ShapeUtil::ForEachMutableSubshape(&new_shape, - [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); - } - }); + ShapeUtil::ForEachMutableSubshape( + &new_shape, [test_type = test_type_](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == F32) { + subshape->set_element_type(test_type); + } + }); return new_shape; } -Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( +Literal ClientLibraryTestBase::MaybeConvertLiteralToTestType( const Literal& literal) { - if (use_bfloat16()) { - return LiteralUtil::ConvertF32ToBF16(literal); + switch (test_type_) { + case BF16: + return LiteralUtil::ConvertF32ToBF16(literal); + case F32: + return literal.Clone(); + case F8E5M2: + return LiteralUtil::ConvertF32ToF8E5M2(literal); + case F8E4M3FN: + return LiteralUtil::ConvertF32ToF8E4M3FN(literal); + default: + LOG(FATAL) << "Unsupported test type: " << test_type_; } - return literal.Clone(); } absl::StatusOr> @@ -650,7 +642,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { - Literal param_literal = MaybeConvertLiteralToBfloat16(literal); + Literal param_literal = MaybeConvertLiteralToTestType(literal); TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(param_literal, device_handle)); *data_handle = diff --git a/xla/tests/client_library_test_base.h b/xla/tests/client_library_test_base.h index 8610dd6e5ae3cb..2814c7032ea425 100644 --- a/xla/tests/client_library_test_base.h +++ b/xla/tests/client_library_test_base.h @@ -61,6 +61,20 @@ std::vector ExpandUseBfloat16( return expanded; } +template +std::vector ExpandTestType( + absl::Span test_type_params, + absl::Span specs) { + std::vector expanded; + for (const PrimitiveType test_type : test_type_params) { + for (const auto& spec : specs) { + expanded.push_back(spec); + expanded.back().test_type = test_type; + } + } + return expanded; +} + // A client library test establishes an in-process XLA client connection. class ClientLibraryTestBase : public ::testing::Test { protected: @@ -236,9 +250,8 @@ class ClientLibraryTestBase : public ::testing::Test { absl::Span arguments, ErrorSpec error); // Create scalar operations for use in reductions. - XlaComputation CreateScalarRelu(); + XlaComputation CreateScalarReluF32(); XlaComputation CreateScalarMax(); - XlaComputation CreateScalarReluSensitivity(); // Special case convenience functions for creating filled arrays. @@ -277,7 +290,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a parameter instruction, transfers the literal for the parameter to // server, then stores into "data_handle" the global handle for that // parameter. When the test_type is bfloat16 but the literal has F32 elements, - // the literal will be converted to BF16 before being transferred. + // the literal will be converted to test_type_ before being transferred. absl::StatusOr> CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, XlaBuilder* builder, XlaOp* data_handle); @@ -304,7 +317,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a constant instruction with the given literal. When the test_type // is bfloat16 but the literal has F32 elements, the literal will be converted - // to BF16 before being transferred. + // to test_type_ before being transferred. XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the test_type is @@ -417,8 +430,8 @@ class ClientLibraryTestBase : public ::testing::Test { absl::StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); - // Converts an f32 literal to bf16 if test_type is BF16. - Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + // Converts a literal to the test_type if the literal's type is F32. + Literal MaybeConvertLiteralToTestType(const Literal& literal); LocalClient* client_; LocalClient* ref_client_; // To compute reference result. @@ -439,10 +452,11 @@ class ClientLibraryTestBase : public ::testing::Test { verify_output, const Shape* output_with_layout = nullptr); - // Converts an f32 shape to bf16 if use_bfloat16_ is true. - Shape MaybeConvertShapeToBfloat16(const Shape& shape); + // Converts an f32 shape to test_type_. + Shape MaybeConvertShapeToTestType(const Shape& shape); - // Type to use when running tests. + // Type to use when running tests. By default, we use F32 for historical + // reasons and we rely on the underlying tests to change it. PrimitiveType test_type_ = F32; // Arguments to be passed to the computation when it runs. @@ -584,9 +598,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR0(value); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -597,9 +609,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( absl::Span values, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR1(values); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -610,9 +620,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -623,9 +631,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -636,9 +642,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR4Parameter( const Array4D& array_4d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -649,9 +653,7 @@ std::unique_ptr ClientLibraryTestBase::CreateParameter( const Array& array, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateFromArray(array); - if (use_bfloat16() && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; diff --git a/xla/tests/reduce_window_test.cc b/xla/tests/reduce_window_test.cc index c65cd9c9af1969..4417ded2499353 100644 --- a/xla/tests/reduce_window_test.cc +++ b/xla/tests/reduce_window_test.cc @@ -1339,7 +1339,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window_dilations=*/param.window_dilation, /*padding=*/padding); - ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)}, + ComputeAndCompare(&b, {MaybeConvertLiteralToTestType(input_literal)}, DefaultErrorSpec()); } };