Skip to content

Commit

Permalink
[IFRT] Use AttributeMap wherever applicable in IFRT
Browse files Browse the repository at this point in the history
This change uses `xla::ifrt::AttributeMap` in IFRT `Client`, `Device`, and
`Topology` attributes. It adds utility functions to convert from/to a map of
`PjRtValueType` that is used in PjRt and was used in these IFRT classes.

IFRT Proxy also uses a native serialization of `xla::ifrt::AttributeMap` in
Version 4 with an added compatibility for previously supported protocol
versions (the minimum version does not change).

PiperOrigin-RevId: 644184301
  • Loading branch information
hyeontaek authored and copybara-github committed Jun 28, 2024
1 parent ff21834 commit c5a9030
Show file tree
Hide file tree
Showing 34 changed files with 356 additions and 220 deletions.
2 changes: 2 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,12 @@ cc_library(
"//xla/pjrt/distributed",
"//xla/pjrt/distributed:client",
"//xla/python/ifrt",
"//xla/python/ifrt:attribute_map",
"//xla/python/ifrt:plugin_program",
"//xla/python/ifrt:plugin_program_serdes",
"//xla/python/ifrt/hlo:hlo_program",
"//xla/python/pjrt_ifrt",
"//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"//xla/python/pjrt_ifrt:xla_host_callback_proto_cc",
"//xla/python/pjrt_ifrt:xla_ifrt",
"//xla/service:computation_placer_hdr",
Expand Down
3 changes: 2 additions & 1 deletion xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ cc_library(
compatible_with = get_compatible_with_portable(),
deps = [
":array_spec_proto_cc",
":attribute_map",
":device_proto_cc",
":dtype_proto_cc",
":remap_plan_proto_cc",
Expand Down Expand Up @@ -368,10 +369,10 @@ cc_library(
srcs = ["mock.cc"],
hdrs = ["mock.h"],
deps = [
":attribute_map",
":ifrt",
"//xla:test",
"//xla/hlo/ir:hlo",
"//xla/pjrt:pjrt_device_description",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_layout",
"//xla/tsl/concurrency:ref_count",
Expand Down
8 changes: 2 additions & 6 deletions xla/python/ifrt/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,17 @@ limitations under the License.
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/dtype.h"
Expand Down Expand Up @@ -198,9 +196,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
// * supports_executable_serialization (bool; default = true): Whether IFRT
// executables produced by this client are serializable. If false, all
// executables from this client are considered not serializable.
using ClientAttribute = xla::PjRtValueType;
virtual absl::flat_hash_map<std::string, ClientAttribute> attributes()
const = 0;
virtual const AttributeMap& Attributes() const = 0;

virtual int device_count() const = 0;
virtual int addressable_device_count() const = 0;
Expand Down
6 changes: 2 additions & 4 deletions xla/python/ifrt/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ limitations under the License.
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/device.pb.h"
#include "tsl/lib/gtl/int_type.h"

Expand Down Expand Up @@ -64,8 +63,7 @@ class Device : public llvm::RTTIExtends<Device, llvm::RTTIRoot> {
// Returns vendor specific attributes about the device. For example the model
// number of a GPU, or the mesh coordinates of a TPU device. The returned
// reference will remain valid for the lifetime of the Device.
virtual const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
Attributes() const = 0;
virtual const AttributeMap& Attributes() const = 0;

// A vendor-dependent string that uniquely identifies the kind of device,
// e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
Expand Down
18 changes: 6 additions & 12 deletions xla/python/ifrt/mock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,13 @@ limitations under the License.
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include <gmock/gmock.h>
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/dtype.h"
Expand Down Expand Up @@ -154,8 +151,8 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
ON_CALL(*this, platform_id).WillByDefault([this]() {
return delegated_->platform_id();
});
ON_CALL(*this, attributes).WillByDefault([this]() {
return delegated_->attributes();
ON_CALL(*this, Attributes).WillByDefault([this]() {
return delegated_->Attributes();
});
ON_CALL(*this, device_count).WillByDefault([this]() {
return delegated_->device_count();
Expand Down Expand Up @@ -219,12 +216,9 @@ MockDevice::MockDevice(Device* delegated) : delegated_(delegated) {
ON_CALL(*this, ToString).WillByDefault([this]() {
return delegated_->ToString();
});
ON_CALL(*this, Attributes)
.WillByDefault(
[this]()
-> const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& {
return delegated_->Attributes();
});
ON_CALL(*this, Attributes).WillByDefault([this]() -> const AttributeMap& {
return delegated_->Attributes();
});
ON_CALL(*this, DefaultMemory).WillByDefault([this]() {
return delegated_->DefaultMemory();
});
Expand Down
9 changes: 3 additions & 6 deletions xla/python/ifrt/mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ limitations under the License.
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device.h"
Expand Down Expand Up @@ -137,8 +137,7 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
MOCK_METHOD(absl::string_view, runtime_type, (), (const, final));
MOCK_METHOD(absl::string_view, platform_name, (), (const, final));
MOCK_METHOD(absl::string_view, platform_version, (), (const, final));
MOCK_METHOD((absl::flat_hash_map<std::string, Client::ClientAttribute>),
attributes, (), (const, final));
MOCK_METHOD((const AttributeMap&), Attributes, (), (const, final));
MOCK_METHOD(int, device_count, (), (const, final));
MOCK_METHOD(PlatformId, platform_id, (), (const, final));
MOCK_METHOD(int, addressable_device_count, (), (const, final));
Expand Down Expand Up @@ -206,9 +205,7 @@ class MockDevice : public Device {
MOCK_METHOD(absl::string_view, Kind, (), (const, final));
MOCK_METHOD(absl::string_view, DebugString, (), (const, final));
MOCK_METHOD(absl::string_view, ToString, (), (const, final));
MOCK_METHOD(
(const absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>&),
Attributes, (), (const, final));
MOCK_METHOD((const AttributeMap&), Attributes, (), (const, final));
MOCK_METHOD(absl::StatusOr<Memory*>, DefaultMemory, (), (const, final));
MOCK_METHOD(absl::Span<Memory* const>, Memories, (), (const, final));
// LINT.ThenChange(mock.cc:MockDeviceDelegation)
Expand Down
5 changes: 2 additions & 3 deletions xla/python/ifrt/topology.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/layout.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/xla_data.pb.h"

namespace xla::ifrt {
Expand Down Expand Up @@ -64,8 +64,7 @@ class Topology : public llvm::RTTIExtends<Topology, llvm::RTTIRoot> {
virtual absl::StatusOr<std::string> Serialize() const = 0;

// Returns vendor specific attributes about the topology.
virtual const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
Attributes() const = 0;
virtual const AttributeMap& Attributes() const = 0;

static char ID; // NOLINT
};
Expand Down
6 changes: 5 additions & 1 deletion xla/python/ifrt_proxy/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ cc_library(
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:pjrt_device_description",
"//xla/python/ifrt",
"//xla/python/ifrt:attribute_map",
"//xla/python/ifrt_proxy/common:common_serdes",
"//xla/python/ifrt_proxy/common:ifrt_service_proto_cc",
"//xla/python/ifrt_proxy/common:types",
"//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"//xla/tsl/concurrency:ref_count",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down Expand Up @@ -151,8 +153,8 @@ ifrt_proxy_cc_test(
":mock_host_buffer",
":rpc_helper",
":version",
"//xla/pjrt:pjrt_device_description",
"//xla/python/ifrt",
"//xla/python/ifrt:attribute_map",
"//xla/python/ifrt_proxy/common:ifrt_service_proto_cc",
"//xla/service:computation_placer_hdr",
"@com_google_absl//absl/status",
Expand All @@ -174,6 +176,8 @@ cc_library(
deps = [
"//xla/pjrt:pjrt_device_description",
"//xla/python/ifrt",
"//xla/python/ifrt:attribute_map",
"//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
24 changes: 18 additions & 6 deletions xla/python/ifrt_proxy/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Support/Casting.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/dtype.h"
Expand All @@ -48,6 +49,7 @@
#include "xla/python/ifrt_proxy/client/rpc_helper.h"
#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h"
#include "xla/python/ifrt_proxy/common/types.h"
#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
Expand Down Expand Up @@ -78,16 +80,24 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
std::vector<xla::ifrt::Device*> addressable_device_ptrs;

for (const auto& d : init_response.devices()) {
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes;
for (const auto& [key, attr] : d.attributes()) {
TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value,
FromVariantProto(attr));
attributes.insert({key, std::move(value)});
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>
pjrt_device_attributes;
AttributeMap::Map attributes;
if (rpc_helper->version().protocol_version() <= 3) {
for (const auto& [key, attr] : d.deprecated_attributes()) {
TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value,
FromVariantProto(attr));
pjrt_device_attributes.insert({key, std::move(value)});
}
} else {
TF_ASSIGN_OR_RETURN(auto attributes,
AttributeMap::FromProto(d.attributes()));
pjrt_device_attributes = ToPjRtDeviceAttributeMap(std::move(attributes));
}

DeviceDescription desc(d.id(), init_response.process_index(),
d.device_kind(), d.debug_string(), d.to_string(),
std::move(attributes));
std::move(pjrt_device_attributes));
bool is_addressable = addressable_device_ids.contains(d.id());

auto device =
Expand Down Expand Up @@ -162,6 +172,8 @@ Client::Client(std::shared_ptr<RpcHelper> rpc_helper, uint64_t session_id,
platform_id_(platform_id),
process_index_(process_index),
runtime_type_(std::move(runtime_type)),
// TODO(b/309059940): Forward the backend attributes to the client.
attributes_(AttributeMap::Map()),
devices_(std::move(devices)),
device_ptrs_(device_ptrs),
addressable_device_ptrs_(std::move(addressable_device_ptrs)),
Expand Down
9 changes: 4 additions & 5 deletions xla/python/ifrt_proxy/client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device.h"
Expand Down Expand Up @@ -101,11 +102,7 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
return platform_version_;
}
PlatformId platform_id() const override { return platform_id_; }
absl::flat_hash_map<std::string, ClientAttribute> attributes()
const override {
// TODO(b/309059940): Forward the backend attributes to the client.
return {};
}
const AttributeMap& Attributes() const override { return attributes_; }
int device_count() const override { return devices().size(); }
int addressable_device_count() const override {
return addressable_devices().size();
Expand Down Expand Up @@ -164,6 +161,8 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
const uint64_t process_index_;
const std::string runtime_type_;

const AttributeMap attributes_;

const absl::flat_hash_map<int, std::unique_ptr<Device>> devices_;
const std::vector<xla::ifrt::Device*> device_ptrs_;
const std::vector<xla::ifrt::Device*> addressable_device_ptrs_;
Expand Down
Loading

0 comments on commit c5a9030

Please sign in to comment.