Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:ffi] Add support for encoding mlir::DictionaryAttr #17670

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/custom_call.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ struct Range {
int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
StructMember<int64_t>("i64"));
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/runtime/custom_call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ limitations under the License.
namespace xla::cpu {
namespace {

using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap;
using AttributesMap = ffi::CallFrameBuilder::AttributesMap;

absl::StatusOr<AttributesMap> ParseAttributes(
absl::string_view backend_config) {
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,15 @@ cc_library(
hdrs = ["attribute_map.h"],
deps = [
":call_frame",
"//xla:util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
8 changes: 4 additions & 4 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -822,10 +822,10 @@ TEST(FfiTest, AttrsAsDictionary) {
}

TEST(FfiTest, DictionaryAttr) {
CallFrameBuilder::FlatAttributesMap dict0;
CallFrameBuilder::AttributesMap dict0;
dict0.try_emplace("i32", 42);

CallFrameBuilder::FlatAttributesMap dict1;
CallFrameBuilder::AttributesMap dict1;
dict1.try_emplace("f32", 42.0f);

CallFrameBuilder::AttributesBuilder attrs;
Expand Down Expand Up @@ -864,7 +864,7 @@ TEST(FfiTest, DictionaryAttr) {
}

TEST(FfiTest, StructAttr) {
CallFrameBuilder::FlatAttributesMap dict;
CallFrameBuilder::AttributesMap dict;
dict.try_emplace("i32", 42);
dict.try_emplace("f32", 42.0f);

Expand Down Expand Up @@ -977,7 +977,7 @@ TEST(FfiTest, EnumAttr) {
}

TEST(FfiTest, WrongEnumAttrType) {
CallFrameBuilder::FlatAttributesMap dict;
CallFrameBuilder::AttributesMap dict;
dict.try_emplace("i32", 42);

CallFrameBuilder::AttributesBuilder attrs;
Expand Down
203 changes: 103 additions & 100 deletions xla/ffi/attribute_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ limitations under the License.
#include "xla/ffi/attribute_map.h"

#include <cstdint>
#include <memory>
#include <string_view>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -27,122 +29,123 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "xla/ffi/call_frame.h"
#include "tsl/platform/errors.h"

using FlatAttribute = xla::ffi::CallFrameBuilder::FlatAttribute;
using FlatAttributesMap = xla::ffi::CallFrameBuilder::FlatAttributesMap;
#include "tsl/platform/statusor.h"

namespace xla::ffi {

absl::StatusOr<FlatAttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict) {
FlatAttributesMap attributes;
for (auto& kv : dict) {
std::string_view name = kv.getName().strref();
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertBoolAttr(
std::string_view name, mlir::BoolAttr boolean) {
return static_cast<bool>(boolean.getValue());
}

auto boolean = [&](mlir::BoolAttr boolean) {
attributes[name] = static_cast<bool>(boolean.getValue());
return absl::OkStatus();
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertStringAttr(
std::string_view name, mlir::StringAttr str) {
return str.getValue().str();
}

auto integer = [&](mlir::IntegerAttr integer) {
if (integer.getType().isUnsignedInteger()) {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
attributes[name] = static_cast<uint8_t>(integer.getUInt());
return absl::OkStatus();
case 16:
attributes[name] = static_cast<uint16_t>(integer.getUInt());
return absl::OkStatus();
case 32:
attributes[name] = static_cast<uint32_t>(integer.getUInt());
return absl::OkStatus();
case 64:
attributes[name] = static_cast<uint64_t>(integer.getUInt());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ",
name));
}
} else {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
attributes[name] = static_cast<int8_t>(integer.getInt());
return absl::OkStatus();
case 16:
attributes[name] = static_cast<int16_t>(integer.getInt());
return absl::OkStatus();
case 32:
attributes[name] = static_cast<int32_t>(integer.getInt());
return absl::OkStatus();
case 64:
attributes[name] = static_cast<int64_t>(integer.getInt());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ",
name));
}
}
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertIntegerAttr(
std::string_view name, mlir::IntegerAttr integer) {
if (integer.getType().isUnsignedInteger()) {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
return static_cast<uint8_t>(integer.getUInt());
case 16:
return static_cast<uint16_t>(integer.getUInt());
case 32:
return static_cast<uint32_t>(integer.getUInt());
case 64:
return static_cast<uint64_t>(integer.getUInt());
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ", name));
}
} else {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
return static_cast<int8_t>(integer.getInt());
case 16:
return static_cast<int16_t>(integer.getInt());
case 32:
return static_cast<int32_t>(integer.getInt());
case 64:
return static_cast<int64_t>(integer.getInt());
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ", name));
}
}
}

auto fp = [&](mlir::FloatAttr fp) {
switch (fp.getType().getIntOrFloatBitWidth()) {
case 32:
attributes[name] = static_cast<float>(fp.getValue().convertToFloat());
return absl::OkStatus();
case 64:
attributes[name] =
static_cast<double>(fp.getValue().convertToDouble());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported float attribute bit width for attribute: ", name));
}
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertFloatAttr(
std::string_view name, mlir::FloatAttr fp) {
switch (fp.getType().getIntOrFloatBitWidth()) {
case 32:
return static_cast<float>(fp.getValue().convertToFloat());
case 64:
return static_cast<double>(fp.getValue().convertToDouble());
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported float attribute bit width for attribute: ", name));
}
}

auto arr = [&](mlir::DenseArrayAttr arr) {
if (auto dense = mlir::dyn_cast<mlir::DenseI8ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else {
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported array element type for attribute: ", name));
}
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertArrayAttr(
std::string_view name, mlir::DenseArrayAttr arr) {
if (auto dense = mlir::dyn_cast<mlir::DenseI8ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else {
return absl::InvalidArgumentError(
absl::StrCat("Unsupported array element type for attribute: ", name));
}
}

auto str = [&](mlir::StringAttr str) {
attributes[name] = str.getValue().str();
return absl::OkStatus();
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDictionaryAttr(
std::string_view name, mlir::DictionaryAttr dict) {
TF_ASSIGN_OR_RETURN(auto attrs, BuildAttributesMap(dict));
return CallFrameBuilder::Dictionary{
std::make_shared<CallFrameBuilder::AttributesMap>(std::move(attrs))};
}

absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict) {
CallFrameBuilder::AttributesMap attributes;
for (auto& kv : dict) {
std::string_view name = kv.getName().strref();
mlir::Attribute value = kv.getValue();

// Wraps attribute conversion function into callable object.
auto convert_with = [&](auto converter_fn) {
return [&, fn = converter_fn](auto attr) -> absl::Status {
TF_ASSIGN_OR_RETURN(attributes[name], fn(name, attr));
return absl::OkStatus();
};
};

TF_RETURN_IF_ERROR(
llvm::TypeSwitch<mlir::Attribute, absl::Status>(kv.getValue())
.Case<mlir::BoolAttr>(boolean)
.Case<mlir::IntegerAttr>(integer)
.Case<mlir::FloatAttr>(fp)
.Case<mlir::DenseArrayAttr>(arr)
.Case<mlir::StringAttr>(str)
llvm::TypeSwitch<mlir::Attribute, absl::Status>(value)
.Case<mlir::BoolAttr>(convert_with(ConvertBoolAttr))
.Case<mlir::IntegerAttr>(convert_with(ConvertIntegerAttr))
.Case<mlir::FloatAttr>(convert_with(ConvertFloatAttr))
.Case<mlir::DenseArrayAttr>(convert_with(ConvertArrayAttr))
.Case<mlir::StringAttr>(convert_with(ConvertStringAttr))
.Case<mlir::DictionaryAttr>(convert_with(ConvertDictionaryAttr))
.Default([&](mlir::Attribute) {
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported attribute type for attribute: ", name));
}));
}

return attributes;
}

} // namespace xla::ffi
2 changes: 1 addition & 1 deletion xla/ffi/attribute_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace xla::ffi {

// Converts MLIR dictionary attribute attached to a custom call operation to a
// custom call handler attributes that are forwarded to the FFI handler.
absl::StatusOr<CallFrameBuilder::FlatAttributesMap> BuildAttributesMap(
absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict);

} // namespace xla::ffi
Expand Down
15 changes: 6 additions & 9 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,17 @@ CallFrameBuilder::AttributesBuilder::AttributesBuilder() = default;
CallFrameBuilder::AttributesBuilder::~AttributesBuilder() = default;

void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
FlatAttribute attr) {
attrs_.try_emplace(std::move(name), FromFlatAttribute(std::move(attr)));
Attribute attr) {
attrs_.try_emplace(std::move(name), std::move(attr));
}

void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
FlatAttributesMap attrs) {
AttributesBuilder builder;
for (auto& [name, attr] : attrs) builder.Insert(name, std::move(attr));

auto attrs_map = std::make_unique<AttributesMap>(builder.Build());
attrs_.try_emplace(std::move(name), Dictionary{std::move(attrs_map)});
AttributesMap attrs) {
attrs_.try_emplace(std::move(name),
Dictionary{std::make_shared<AttributesMap>(attrs)});
}

void CallFrameBuilder::AttributesBuilder::Append(FlatAttributesMap attrs) {
void CallFrameBuilder::AttributesBuilder::Append(AttributesMap attrs) {
for (auto& [name, attr] : attrs) Insert(name, std::move(attr));
}

Expand Down
15 changes: 8 additions & 7 deletions xla/ffi/call_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ class CallFrameBuilder {
using AttributesMap = absl::flat_hash_map<std::string, Attribute>;

// Dictionary is just a wrapper around AttributesMap. We need an indirection
// through `std::unique_ptr` to be able to define recursive `std::variant`.
// through `std::shared_ptr` to be able to define recursive `std::variant`. We
// use shared pointer to keep `AttributesMap` copyable.
struct Dictionary {
std::unique_ptr<AttributesMap> attrs;
std::shared_ptr<AttributesMap> attrs;
};

// A helper class to build call frame attributes.
Expand All @@ -92,14 +93,14 @@ class CallFrameBuilder {
AttributesBuilder();
~AttributesBuilder();

void Insert(std::string name, Attribute attr);
void Insert(std::string name, AttributesMap attrs);
void Append(AttributesMap attrs);

// This overload is only necessary to support older GCC versions.
void Insert(std::string name, const char* attr) {
Insert(std::move(name), std::string(attr));
Insert(std::move(name), Attribute{std::string(attr)});
}
void Insert(std::string name, FlatAttribute attr);
void Insert(std::string name, FlatAttributesMap attrs);

void Append(FlatAttributesMap attrs);

AttributesMap Build();

Expand Down
6 changes: 3 additions & 3 deletions xla/ffi/call_frame_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ void BM_AddBufferArg(benchmark::State& state) {
void BM_AddAttributes(benchmark::State& state) {
size_t num_attrs = state.range(0);

CallFrameBuilder::FlatAttributesMap flat_attrs;
CallFrameBuilder::AttributesMap attrs;
for (size_t i = 0; i < num_attrs; ++i) {
flat_attrs.try_emplace(absl::StrCat("attr_", i), 42);
attrs.try_emplace(absl::StrCat("attr_", i), 42);
}

for (auto _ : state) {
CallFrameBuilder::AttributesBuilder attrs_builder;
attrs_builder.Append(flat_attrs);
attrs_builder.Append(attrs);

CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
builder.AddAttributes(attrs_builder.Build());
Expand Down
Loading
Loading