Skip to content

Commit

Permalink
[xla:ffi] Add support for encoding mlir::DictionaryAttr
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679302245
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Sep 26, 2024
1 parent 95beb0e commit 3df9bd2
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 136 deletions.
4 changes: 2 additions & 2 deletions docs/custom_call.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ struct Range {
int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
StructMember<int64_t>("i64"));
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/runtime/custom_call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ limitations under the License.
namespace xla::cpu {
namespace {

using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap;
using AttributesMap = ffi::CallFrameBuilder::AttributesMap;

absl::StatusOr<AttributesMap> ParseAttributes(
absl::string_view backend_config) {
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,15 @@ cc_library(
hdrs = ["attribute_map.h"],
deps = [
":call_frame",
"//xla:util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
8 changes: 4 additions & 4 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -822,10 +822,10 @@ TEST(FfiTest, AttrsAsDictionary) {
}

TEST(FfiTest, DictionaryAttr) {
CallFrameBuilder::FlatAttributesMap dict0;
CallFrameBuilder::AttributesMap dict0;
dict0.try_emplace("i32", 42);

CallFrameBuilder::FlatAttributesMap dict1;
CallFrameBuilder::AttributesMap dict1;
dict1.try_emplace("f32", 42.0f);

CallFrameBuilder::AttributesBuilder attrs;
Expand Down Expand Up @@ -864,7 +864,7 @@ TEST(FfiTest, DictionaryAttr) {
}

TEST(FfiTest, StructAttr) {
CallFrameBuilder::FlatAttributesMap dict;
CallFrameBuilder::AttributesMap dict;
dict.try_emplace("i32", 42);
dict.try_emplace("f32", 42.0f);

Expand Down Expand Up @@ -977,7 +977,7 @@ TEST(FfiTest, EnumAttr) {
}

TEST(FfiTest, WrongEnumAttrType) {
CallFrameBuilder::FlatAttributesMap dict;
CallFrameBuilder::AttributesMap dict;
dict.try_emplace("i32", 42);

CallFrameBuilder::AttributesBuilder attrs;
Expand Down
203 changes: 103 additions & 100 deletions xla/ffi/attribute_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ limitations under the License.
#include "xla/ffi/attribute_map.h"

#include <cstdint>
#include <memory>
#include <string_view>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -27,122 +29,123 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "xla/ffi/call_frame.h"
#include "tsl/platform/errors.h"

using FlatAttribute = xla::ffi::CallFrameBuilder::FlatAttribute;
using FlatAttributesMap = xla::ffi::CallFrameBuilder::FlatAttributesMap;
#include "tsl/platform/statusor.h"

namespace xla::ffi {

absl::StatusOr<FlatAttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict) {
FlatAttributesMap attributes;
for (auto& kv : dict) {
std::string_view name = kv.getName().strref();
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertBoolAttr(
std::string_view name, mlir::BoolAttr boolean) {
return static_cast<bool>(boolean.getValue());
}

auto boolean = [&](mlir::BoolAttr boolean) {
attributes[name] = static_cast<bool>(boolean.getValue());
return absl::OkStatus();
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertStringAttr(
std::string_view name, mlir::StringAttr str) {
return str.getValue().str();
}

auto integer = [&](mlir::IntegerAttr integer) {
if (integer.getType().isUnsignedInteger()) {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
attributes[name] = static_cast<uint8_t>(integer.getUInt());
return absl::OkStatus();
case 16:
attributes[name] = static_cast<uint16_t>(integer.getUInt());
return absl::OkStatus();
case 32:
attributes[name] = static_cast<uint32_t>(integer.getUInt());
return absl::OkStatus();
case 64:
attributes[name] = static_cast<uint64_t>(integer.getUInt());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ",
name));
}
} else {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
attributes[name] = static_cast<int8_t>(integer.getInt());
return absl::OkStatus();
case 16:
attributes[name] = static_cast<int16_t>(integer.getInt());
return absl::OkStatus();
case 32:
attributes[name] = static_cast<int32_t>(integer.getInt());
return absl::OkStatus();
case 64:
attributes[name] = static_cast<int64_t>(integer.getInt());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ",
name));
}
}
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertIntegerAttr(
std::string_view name, mlir::IntegerAttr integer) {
if (integer.getType().isUnsignedInteger()) {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
return static_cast<uint8_t>(integer.getUInt());
case 16:
return static_cast<uint16_t>(integer.getUInt());
case 32:
return static_cast<uint32_t>(integer.getUInt());
case 64:
return static_cast<uint64_t>(integer.getUInt());
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ", name));
}
} else {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
return static_cast<int8_t>(integer.getInt());
case 16:
return static_cast<int16_t>(integer.getInt());
case 32:
return static_cast<int32_t>(integer.getInt());
case 64:
return static_cast<int64_t>(integer.getInt());
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ", name));
}
}
}

auto fp = [&](mlir::FloatAttr fp) {
switch (fp.getType().getIntOrFloatBitWidth()) {
case 32:
attributes[name] = static_cast<float>(fp.getValue().convertToFloat());
return absl::OkStatus();
case 64:
attributes[name] =
static_cast<double>(fp.getValue().convertToDouble());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported float attribute bit width for attribute: ", name));
}
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertFloatAttr(
std::string_view name, mlir::FloatAttr fp) {
switch (fp.getType().getIntOrFloatBitWidth()) {
case 32:
return static_cast<float>(fp.getValue().convertToFloat());
case 64:
return static_cast<double>(fp.getValue().convertToDouble());
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported float attribute bit width for attribute: ", name));
}
}

auto arr = [&](mlir::DenseArrayAttr arr) {
if (auto dense = mlir::dyn_cast<mlir::DenseI8ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
attributes[name] = dense.asArrayRef().vec();
return absl::OkStatus();
} else {
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported array element type for attribute: ", name));
}
};
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertArrayAttr(
std::string_view name, mlir::DenseArrayAttr arr) {
if (auto dense = mlir::dyn_cast<mlir::DenseI8ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else {
return absl::InvalidArgumentError(
absl::StrCat("Unsupported array element type for attribute: ", name));
}
}

auto str = [&](mlir::StringAttr str) {
attributes[name] = str.getValue().str();
return absl::OkStatus();
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDictionaryAttr(
std::string_view name, mlir::DictionaryAttr dict) {
TF_ASSIGN_OR_RETURN(auto attrs, BuildAttributesMap(dict));
return CallFrameBuilder::Dictionary{
std::make_shared<CallFrameBuilder::AttributesMap>(std::move(attrs))};
}

absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict) {
CallFrameBuilder::AttributesMap attributes;
for (auto& kv : dict) {
std::string_view name = kv.getName().strref();
mlir::Attribute value = kv.getValue();

// Wraps attribute conversion function into callable object.
auto convert_with = [&](auto converter_fn) {
return [&, fn = converter_fn](auto attr) -> absl::Status {
TF_ASSIGN_OR_RETURN(attributes[name], fn(name, attr));
return absl::OkStatus();
};
};

TF_RETURN_IF_ERROR(
llvm::TypeSwitch<mlir::Attribute, absl::Status>(kv.getValue())
.Case<mlir::BoolAttr>(boolean)
.Case<mlir::IntegerAttr>(integer)
.Case<mlir::FloatAttr>(fp)
.Case<mlir::DenseArrayAttr>(arr)
.Case<mlir::StringAttr>(str)
llvm::TypeSwitch<mlir::Attribute, absl::Status>(value)
.Case<mlir::BoolAttr>(convert_with(ConvertBoolAttr))
.Case<mlir::IntegerAttr>(convert_with(ConvertIntegerAttr))
.Case<mlir::FloatAttr>(convert_with(ConvertFloatAttr))
.Case<mlir::DenseArrayAttr>(convert_with(ConvertArrayAttr))
.Case<mlir::StringAttr>(convert_with(ConvertStringAttr))
.Case<mlir::DictionaryAttr>(convert_with(ConvertDictionaryAttr))
.Default([&](mlir::Attribute) {
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported attribute type for attribute: ", name));
}));
}

return attributes;
}

} // namespace xla::ffi
2 changes: 1 addition & 1 deletion xla/ffi/attribute_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace xla::ffi {

// Converts MLIR dictionary attribute attached to a custom call operation to a
// custom call handler attributes that are forwarded to the FFI handler.
absl::StatusOr<CallFrameBuilder::FlatAttributesMap> BuildAttributesMap(
absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict);

} // namespace xla::ffi
Expand Down
15 changes: 6 additions & 9 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,17 @@ CallFrameBuilder::AttributesBuilder::AttributesBuilder() = default;
CallFrameBuilder::AttributesBuilder::~AttributesBuilder() = default;

void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
FlatAttribute attr) {
attrs_.try_emplace(std::move(name), FromFlatAttribute(std::move(attr)));
Attribute attr) {
attrs_.try_emplace(std::move(name), std::move(attr));
}

void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
FlatAttributesMap attrs) {
AttributesBuilder builder;
for (auto& [name, attr] : attrs) builder.Insert(name, std::move(attr));

auto attrs_map = std::make_unique<AttributesMap>(builder.Build());
attrs_.try_emplace(std::move(name), Dictionary{std::move(attrs_map)});
AttributesMap attrs) {
attrs_.try_emplace(std::move(name),
Dictionary{std::make_shared<AttributesMap>(attrs)});
}

void CallFrameBuilder::AttributesBuilder::Append(FlatAttributesMap attrs) {
void CallFrameBuilder::AttributesBuilder::Append(AttributesMap attrs) {
for (auto& [name, attr] : attrs) Insert(name, std::move(attr));
}

Expand Down
15 changes: 8 additions & 7 deletions xla/ffi/call_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ class CallFrameBuilder {
using AttributesMap = absl::flat_hash_map<std::string, Attribute>;

// Dictionary is just a wrapper around AttributesMap. We need an indirection
// through `std::unique_ptr` to be able to define recursive `std::variant`.
// through `std::shared_ptr` to be able to define recursive `std::variant`. We
// use shared pointer to keep `AttributesMap` copyable.
struct Dictionary {
std::unique_ptr<AttributesMap> attrs;
std::shared_ptr<AttributesMap> attrs;
};

// A helper class to build call frame attributes.
Expand All @@ -92,14 +93,14 @@ class CallFrameBuilder {
AttributesBuilder();
~AttributesBuilder();

void Insert(std::string name, Attribute attr);
void Insert(std::string name, AttributesMap attrs);
void Append(AttributesMap attrs);

// This overload is only necessary to support older GCC versions.
void Insert(std::string name, const char* attr) {
Insert(std::move(name), std::string(attr));
Insert(std::move(name), Attribute{std::string(attr)});
}
void Insert(std::string name, FlatAttribute attr);
void Insert(std::string name, FlatAttributesMap attrs);

void Append(FlatAttributesMap attrs);

AttributesMap Build();

Expand Down
6 changes: 3 additions & 3 deletions xla/ffi/call_frame_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ void BM_AddBufferArg(benchmark::State& state) {
void BM_AddAttributes(benchmark::State& state) {
size_t num_attrs = state.range(0);

CallFrameBuilder::FlatAttributesMap flat_attrs;
CallFrameBuilder::AttributesMap attrs;
for (size_t i = 0; i < num_attrs; ++i) {
flat_attrs.try_emplace(absl::StrCat("attr_", i), 42);
attrs.try_emplace(absl::StrCat("attr_", i), 42);
}

for (auto _ : state) {
CallFrameBuilder::AttributesBuilder attrs_builder;
attrs_builder.Append(flat_attrs);
attrs_builder.Append(attrs);

CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
builder.AddAttributes(attrs_builder.Build());
Expand Down
Loading

0 comments on commit 3df9bd2

Please sign in to comment.