Skip to content

Commit

Permalink
Support multiple floating point types in client library test base
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678256749
  • Loading branch information
Google-ML-Automation committed Sep 30, 2024
1 parent 41fdc8a commit 3267ec4
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 77 deletions.
10 changes: 10 additions & 0 deletions xla/literal_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal,
return ConvertType<float, tsl::float8_e5m2fnuz>(f32_literal);
}

/* static */ Literal LiteralUtil::ConvertF32ToF8E5M2(
const LiteralSlice& f32_literal) {
return ConvertType<float, tsl::float8_e5m2>(f32_literal);
}

/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FN(
const LiteralSlice& f32_literal) {
return ConvertType<float, tsl::float8_e4m3fn>(f32_literal);
}

/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
Expand Down
2 changes: 2 additions & 0 deletions xla/literal_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ class LiteralUtil {
static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF8E5M2FNUZ(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF8E5M2(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF8E4M3FN(const LiteralSlice& f32_literal);
static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
static Literal ConvertF32ToS8(const LiteralSlice& f32_literal);
static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
Expand Down
1 change: 1 addition & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ cc_library(
"//xla/client:local_client",
"//xla/client:xla_builder",
"//xla/client:xla_computation",
"//xla/hlo/builder:xla_builder",
"//xla/service:interpreter_plugin", # reference backend
"//xla/service:platform_util",
"//xla/stream_executor",
Expand Down
90 changes: 41 additions & 49 deletions xla/tests/client_library_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ limitations under the License.
#include "xla/client/local_client.h"
#include "xla/client/xla_builder.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/service/platform_util.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/test_helpers.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/logging.h"

namespace xla {
Expand Down Expand Up @@ -291,7 +293,7 @@ absl::StatusOr<Literal> ClientLibraryTestBase::ComputeAndTransfer(
for (const auto& argument : arguments_) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<GlobalData> owned_argument,
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
client_->TransferToServer(MaybeConvertLiteralToTestType(argument)));
owning_arguments.push_back(std::move(owned_argument));
arguments.push_back(owning_arguments.back().get());
}
Expand All @@ -315,7 +317,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
for (const auto& argument : arguments_) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<GlobalData> owned_argument,
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
client_->TransferToServer(MaybeConvertLiteralToTestType(argument)));
owning_arguments.push_back(std::move(owned_argument));
arguments.push_back(owning_arguments.back().get());
}
Expand All @@ -326,20 +328,20 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
ShapeUtil::ElementIsComplex(expected.shape())) {
LOG(WARNING) << "performing exact comparison of floating point numbers";
}
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
// We allow using a float expected literal for non float outputs. In this
// case, we need to convert the expected literal to test_type_.
const Literal* expected_ptr = &expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16()) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
if (test_type_ != F32) {
converted_expected = MaybeConvertLiteralToTestType(expected);
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
&layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
if (subshape->element_type() == F32) {
subshape->set_element_type(BF16);
subshape->set_element_type(test_type_);
}
});
shape_with_layout = &layout_shape;
Expand Down Expand Up @@ -377,27 +379,27 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
for (const auto& argument : arguments_) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<GlobalData> owned_argument,
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
client_->TransferToServer(MaybeConvertLiteralToTestType(argument)));
owning_arguments.push_back(std::move(owned_argument));
arguments.push_back(owning_arguments.back().get());
}
}

TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
// We allow using a float expected literal for a non float outputs. In this
// case, we need to convert the expected literal to type_test_.
const Literal* expected_ptr = &expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16()) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
if (test_type_ != F32) {
converted_expected = MaybeConvertLiteralToTestType(expected);
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
&layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
if (subshape->element_type() == F32) {
subshape->set_element_type(BF16);
subshape->set_element_type(test_type_);
}
});
shape_with_layout = &layout_shape;
Expand Down Expand Up @@ -535,13 +537,11 @@ ClientLibraryTestBase::ComputeValueAndReference(
return std::make_pair(std::move(reference), std::move(result));
}

XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
XlaComputation ClientLibraryTestBase::CreateScalarReluF32() {
XlaBuilder builder("relu");
auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {});
auto shape = ShapeUtil::MakeShape(F32, {});
auto z_value = Parameter(&builder, 0, shape, "z_value");
auto zero = use_bfloat16()
? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
: ConstantR0<float>(&builder, 0.0f);
auto zero = ConstantR0<float>(&builder, 0.0f);
Max(z_value, zero);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
Expand All @@ -550,7 +550,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() {

XlaComputation ClientLibraryTestBase::CreateScalarMax() {
XlaBuilder builder("max");
auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {});
auto shape = ShapeUtil::MakeShape(test_type_, {});
auto x = Parameter(&builder, 0, shape, "x");
auto y = Parameter(&builder, 1, shape, "y");
Max(x, y);
Expand All @@ -559,22 +559,6 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() {
return std::move(computation_status).value();
}

XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
XlaBuilder builder("relu_sensitivity");
auto shape = ShapeUtil::MakeShape(use_bfloat16() ? BF16 : F32, {});
auto activation = Parameter(&builder, 0, shape, "activation");
auto backprop = Parameter(&builder, 1, shape, "backprop");
auto zero = use_bfloat16()
? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
: ConstantR0<float>(&builder, 0.0f);
auto activation_gtz = Gt(activation, zero);
Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);

auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
return std::move(computation_status).value();
}

std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
int rows, int cols, float offset) {
auto array = std::make_unique<Array2D<float>>(rows, cols);
Expand Down Expand Up @@ -605,7 +589,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaBuilder* builder) {
arguments_.push_back(argument.Clone());
return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
MaybeConvertShapeToBfloat16(argument.shape()), "");
MaybeConvertShapeToTestType(argument.shape()), "");
}

XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
Expand All @@ -623,34 +607,42 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
nullptr, builder, data_handle);
}

Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
if (!use_bfloat16()) {
Shape ClientLibraryTestBase::MaybeConvertShapeToTestType(const Shape& shape) {
if (test_type_ == F32) {
return shape;
}
Shape new_shape = shape;
ShapeUtil::ForEachMutableSubshape(&new_shape,
[](Shape* subshape, const ShapeIndex&) {
if (subshape->element_type() == F32) {
subshape->set_element_type(BF16);
}
});
ShapeUtil::ForEachMutableSubshape(
&new_shape, [test_type = test_type_](Shape* subshape, const ShapeIndex&) {
if (subshape->element_type() == F32) {
subshape->set_element_type(test_type);
}
});
return new_shape;
}

Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
Literal ClientLibraryTestBase::MaybeConvertLiteralToTestType(
const Literal& literal) {
if (use_bfloat16()) {
return LiteralUtil::ConvertF32ToBF16(literal);
switch (test_type_) {
case BF16:
return LiteralUtil::ConvertF32ToBF16(literal);
case F32:
return literal.Clone();
case F8E5M2:
return LiteralUtil::ConvertF32ToF8E5M2(literal);
case F8E4M3FN:
return LiteralUtil::ConvertF32ToF8E4M3FN(literal);
default:
LOG(FATAL) << "Unsupported test type: " << test_type_;
}
return literal.Clone();
}

absl::StatusOr<std::unique_ptr<GlobalData>>
ClientLibraryTestBase::CreateParameterAndTransferLiteral(
int64_t parameter_number, const Literal& literal, const std::string& name,
const DeviceHandle* device_handle, XlaBuilder* builder,
XlaOp* data_handle) {
Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
Literal param_literal = MaybeConvertLiteralToTestType(literal);
TF_ASSIGN_OR_RETURN(auto data,
client_->TransferToServer(param_literal, device_handle));
*data_handle =
Expand Down
56 changes: 29 additions & 27 deletions xla/tests/client_library_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ std::vector<TestCase> ExpandUseBfloat16(
return expanded;
}

template <typename TestCase>
std::vector<TestCase> ExpandTestType(
absl::Span<const PrimitiveType> test_type_params,
absl::Span<const TestCase> specs) {
std::vector<TestCase> expanded;
for (const PrimitiveType test_type : test_type_params) {
for (const auto& spec : specs) {
expanded.push_back(spec);
expanded.back().test_type = test_type;
}
}
return expanded;
}

// A client library test establishes an in-process XLA client connection.
class ClientLibraryTestBase : public ::testing::Test {
protected:
Expand Down Expand Up @@ -236,9 +250,8 @@ class ClientLibraryTestBase : public ::testing::Test {
absl::Span<GlobalData* const> arguments,
ErrorSpec error);
// Create scalar operations for use in reductions.
XlaComputation CreateScalarRelu();
XlaComputation CreateScalarReluF32();
XlaComputation CreateScalarMax();
XlaComputation CreateScalarReluSensitivity();

// Special case convenience functions for creating filled arrays.

Expand Down Expand Up @@ -277,7 +290,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// Creates a parameter instruction, transfers the literal for the parameter to
// server, then stores into "data_handle" the global handle for that
// parameter. When the test_type is bfloat16 but the literal has F32 elements,
// the literal will be converted to BF16 before being transferred.
// the literal will be converted to test_type_ before being transferred.
absl::StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
int64_t parameter_number, const Literal& literal, const std::string& name,
XlaBuilder* builder, XlaOp* data_handle);
Expand All @@ -304,7 +317,7 @@ class ClientLibraryTestBase : public ::testing::Test {

// Creates a constant instruction with the given literal. When the test_type
// is bfloat16 but the literal has F32 elements, the literal will be converted
// to BF16 before being transferred.
// to test_type_ before being transferred.
XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);

// Creates a constant instruction with the given array. When the test_type is
Expand Down Expand Up @@ -417,8 +430,8 @@ class ClientLibraryTestBase : public ::testing::Test {
absl::StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments);

// Converts an f32 literal to bf16 if test_type is BF16.
Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
// Converts a literal to the test_type if the literal's type is F32.
Literal MaybeConvertLiteralToTestType(const Literal& literal);

LocalClient* client_;
LocalClient* ref_client_; // To compute reference result.
Expand All @@ -439,10 +452,11 @@ class ClientLibraryTestBase : public ::testing::Test {
verify_output,
const Shape* output_with_layout = nullptr);

// Converts an f32 shape to bf16 if use_bfloat16_ is true.
Shape MaybeConvertShapeToBfloat16(const Shape& shape);
// Converts an f32 shape to test_type_.
Shape MaybeConvertShapeToTestType(const Shape& shape);

// Type to use when running tests.
// Type to use when running tests. By default, we use F32 for historical
// reasons and we rely on the underlying tests to change it.
PrimitiveType test_type_ = F32;

// Arguments to be passed to the computation when it runs.
Expand Down Expand Up @@ -584,9 +598,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64_t parameter_number, const std::string& name,
XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR0(value);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -597,9 +609,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
absl::Span<const NativeT> values, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR1(values);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -610,9 +620,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -623,9 +631,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -636,9 +642,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
const Array4D<NativeT>& array_4d, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand All @@ -649,9 +653,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
const Array<NativeT>& array, int64_t parameter_number,
const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
Literal literal = LiteralUtil::CreateFromArray(array);
if (use_bfloat16() && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
literal = MaybeConvertLiteralToTestType(literal);
std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
Expand Down
2 changes: 1 addition & 1 deletion xla/tests/reduce_window_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window_dilations=*/param.window_dilation,
/*padding=*/padding);

ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)},
ComputeAndCompare(&b, {MaybeConvertLiteralToTestType(input_literal)},
DefaultErrorSpec());
}
};
Expand Down

0 comments on commit 3267ec4

Please sign in to comment.