diff --git a/xla/python/BUILD b/xla/python/BUILD index f77b00016d6207..23e5fda3d78e1f 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -395,10 +395,12 @@ cc_library( "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", "//xla/python/ifrt:plugin_program", "//xla/python/ifrt:plugin_program_serdes", "//xla/python/ifrt/hlo:hlo_program", "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index 8ad430cabe5061..4e8c0aa324f3c7 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -82,6 +82,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":array_spec_proto_cc", + ":attribute_map", ":device_proto_cc", ":dtype_proto_cc", ":remap_plan_proto_cc", @@ -368,10 +369,10 @@ cc_library( srcs = ["mock.cc"], hdrs = ["mock.h"], deps = [ + ":attribute_map", ":ifrt", "//xla:test", "//xla/hlo/ir:hlo", - "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_layout", "//xla/tsl/concurrency:ref_count", diff --git a/xla/python/ifrt/client.h b/xla/python/ifrt/client.h index 3883cb0af35fdd..5f75c196d159c7 100644 --- a/xla/python/ifrt/client.h +++ b/xla/python/ifrt/client.h @@ -20,19 +20,17 @@ limitations under the License. #include #include #include -#include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" @@ -198,9 +196,7 @@ class Client : public llvm::RTTIExtends { // * supports_executable_serialization (bool; default = true): Whether IFRT // executables produced by this client are serializable. If false, all // executables from this client are considered not serializable. - using ClientAttribute = xla::PjRtValueType; - virtual absl::flat_hash_map attributes() - const = 0; + virtual const AttributeMap& Attributes() const = 0; virtual int device_count() const = 0; virtual int addressable_device_count() const = 0; diff --git a/xla/python/ifrt/device.h b/xla/python/ifrt/device.h index a93b60454b4337..d1fec0784545a0 100644 --- a/xla/python/ifrt/device.h +++ b/xla/python/ifrt/device.h @@ -24,14 +24,13 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" -#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.pb.h" #include "tsl/lib/gtl/int_type.h" @@ -64,8 +63,7 @@ class Device : public llvm::RTTIExtends { // Returns vendor specific attributes about the device. For example the model // number of a GPU, or the mesh coordinates of a TPU device. The returned // reference will remain valid for the lifetime of the Device. - virtual const absl::flat_hash_map& - Attributes() const = 0; + virtual const AttributeMap& Attributes() const = 0; // A vendor-dependent string that uniquely identifies the kind of device, // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are diff --git a/xla/python/ifrt/mock.cc b/xla/python/ifrt/mock.cc index 2972adad4d0186..0a5ff16d69a8d5 100644 --- a/xla/python/ifrt/mock.cc +++ b/xla/python/ifrt/mock.cc @@ -19,16 +19,13 @@ limitations under the License. #include #include #include -#include -#include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" @@ -154,8 +151,8 @@ MockClient::MockClient(std::unique_ptr delegated) ON_CALL(*this, platform_id).WillByDefault([this]() { return delegated_->platform_id(); }); - ON_CALL(*this, attributes).WillByDefault([this]() { - return delegated_->attributes(); + ON_CALL(*this, Attributes).WillByDefault([this]() { + return delegated_->Attributes(); }); ON_CALL(*this, device_count).WillByDefault([this]() { return delegated_->device_count(); @@ -219,12 +216,9 @@ MockDevice::MockDevice(Device* delegated) : delegated_(delegated) { ON_CALL(*this, ToString).WillByDefault([this]() { return delegated_->ToString(); }); - ON_CALL(*this, Attributes) - .WillByDefault( - [this]() - -> const absl::flat_hash_map& { - return delegated_->Attributes(); - }); + ON_CALL(*this, Attributes).WillByDefault([this]() -> const AttributeMap& { + return delegated_->Attributes(); + }); ON_CALL(*this, DefaultMemory).WillByDefault([this]() { return delegated_->DefaultMemory(); }); diff --git a/xla/python/ifrt/mock.h b/xla/python/ifrt/mock.h index ae1810f6fb3a82..dd89cd63b6a70f 100644 --- a/xla/python/ifrt/mock.h +++ b/xla/python/ifrt/mock.h @@ -30,10 +30,10 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -137,8 +137,7 @@ class MockClient : public llvm::RTTIExtends { MOCK_METHOD(absl::string_view, runtime_type, (), (const, final)); MOCK_METHOD(absl::string_view, platform_name, (), (const, final)); MOCK_METHOD(absl::string_view, platform_version, (), (const, final)); - MOCK_METHOD((absl::flat_hash_map), - attributes, (), (const, final)); + MOCK_METHOD((const AttributeMap&), Attributes, (), (const, final)); MOCK_METHOD(int, device_count, (), (const, final)); MOCK_METHOD(PlatformId, platform_id, (), (const, final)); MOCK_METHOD(int, addressable_device_count, (), (const, final)); @@ -206,9 +205,7 @@ class MockDevice : public Device { MOCK_METHOD(absl::string_view, Kind, (), (const, final)); MOCK_METHOD(absl::string_view, DebugString, (), (const, final)); MOCK_METHOD(absl::string_view, ToString, (), (const, final)); - MOCK_METHOD( - (const absl::flat_hash_map&), - Attributes, (), (const, final)); + MOCK_METHOD((const AttributeMap&), Attributes, (), (const, final)); MOCK_METHOD(absl::StatusOr, DefaultMemory, (), (const, final)); MOCK_METHOD(absl::Span, Memories, (), (const, final)); // LINT.ThenChange(mock.cc:MockDeviceDelegation) diff --git a/xla/python/ifrt/topology.h b/xla/python/ifrt/topology.h index 3926b98f16fcd1..8d1104aca01f33 100644 --- a/xla/python/ifrt/topology.h +++ b/xla/python/ifrt/topology.h @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -29,6 +28,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/xla_data.pb.h" namespace xla::ifrt { @@ -64,8 +64,7 @@ class Topology : public llvm::RTTIExtends { virtual absl::StatusOr Serialize() const = 0; // Returns vendor specific attributes about the topology. - virtual const absl::flat_hash_map& - Attributes() const = 0; + virtual const AttributeMap& Attributes() const = 0; static char ID; // NOLINT }; diff --git a/xla/python/ifrt_proxy/client/BUILD b/xla/python/ifrt_proxy/client/BUILD index 1751443b86d9c7..eb31ea2a9f8e56 100644 --- a/xla/python/ifrt_proxy/client/BUILD +++ b/xla/python/ifrt_proxy/client/BUILD @@ -121,9 +121,11 @@ cc_library( "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_device_description", "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", "//xla/python/ifrt_proxy/common:common_serdes", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", "//xla/python/ifrt_proxy/common:types", + "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -151,8 +153,8 @@ ifrt_proxy_cc_test( ":mock_host_buffer", ":rpc_helper", ":version", - "//xla/pjrt:pjrt_device_description", "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", "//xla/service:computation_placer_hdr", "@com_google_absl//absl/status", @@ -174,6 +176,8 @@ cc_library( deps = [ "//xla/pjrt:pjrt_device_description", "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", + "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/xla/python/ifrt_proxy/client/client.cc b/xla/python/ifrt_proxy/client/client.cc index 42c83556440129..faf868bd2ea168 100644 --- a/xla/python/ifrt_proxy/client/client.cc +++ b/xla/python/ifrt_proxy/client/client.cc @@ -33,6 +33,7 @@ #include "llvm/Support/Casting.h" #include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" @@ -48,6 +49,7 @@ #include "xla/python/ifrt_proxy/client/rpc_helper.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" #include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -78,16 +80,24 @@ absl::StatusOr> Client::Create( std::vector addressable_device_ptrs; for (const auto& d : init_response.devices()) { - absl::flat_hash_map attributes; - for (const auto& [key, attr] : d.attributes()) { - TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value, - FromVariantProto(attr)); - attributes.insert({key, std::move(value)}); + absl::flat_hash_map + pjrt_device_attributes; + AttributeMap::Map attributes; + if (rpc_helper->version().protocol_version() <= 3) { + for (const auto& [key, attr] : d.deprecated_attributes()) { + TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value, + FromVariantProto(attr)); + pjrt_device_attributes.insert({key, std::move(value)}); + } + } else { + TF_ASSIGN_OR_RETURN(auto attributes, + AttributeMap::FromProto(d.attributes())); + pjrt_device_attributes = ToPjRtDeviceAttributeMap(std::move(attributes)); } DeviceDescription desc(d.id(), init_response.process_index(), d.device_kind(), d.debug_string(), d.to_string(), - std::move(attributes)); + std::move(pjrt_device_attributes)); bool is_addressable = addressable_device_ids.contains(d.id()); auto device = @@ -162,6 +172,8 @@ Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, platform_id_(platform_id), process_index_(process_index), runtime_type_(std::move(runtime_type)), + // TODO(b/309059940): Forward the backend attributes to the client. + attributes_(AttributeMap::Map()), devices_(std::move(devices)), device_ptrs_(device_ptrs), addressable_device_ptrs_(std::move(addressable_device_ptrs)), diff --git a/xla/python/ifrt_proxy/client/client.h b/xla/python/ifrt_proxy/client/client.h index 00939710df1bac..1cfa791d1b7c5a 100644 --- a/xla/python/ifrt_proxy/client/client.h +++ b/xla/python/ifrt_proxy/client/client.h @@ -33,6 +33,7 @@ #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -101,11 +102,7 @@ class Client final : public llvm::RTTIExtends { return platform_version_; } PlatformId platform_id() const override { return platform_id_; } - absl::flat_hash_map attributes() - const override { - // TODO(b/309059940): Forward the backend attributes to the client. - return {}; - } + const AttributeMap& Attributes() const override { return attributes_; } int device_count() const override { return devices().size(); } int addressable_device_count() const override { return addressable_devices().size(); @@ -164,6 +161,8 @@ class Client final : public llvm::RTTIExtends { const uint64_t process_index_; const std::string runtime_type_; + const AttributeMap attributes_; + const absl::flat_hash_map> devices_; const std::vector device_ptrs_; const std::vector addressable_device_ptrs_; diff --git a/xla/python/ifrt_proxy/client/client_test.cc b/xla/python/ifrt_proxy/client/client_test.cc index f565990d88c174..3f1dbb45c7dea6 100644 --- a/xla/python/ifrt_proxy/client/client_test.cc +++ b/xla/python/ifrt_proxy/client/client_test.cc @@ -22,7 +22,7 @@ #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" @@ -75,50 +75,101 @@ class ClientTest : public ::testing::TestWithParam { rpc_helper_->set_host_buffer_store(host_buffer_store_); InitResponse response; - ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( - R"pb( - platform_name: "ifrt-service" - platform_version: "n/a" - platform_id: 42 - process_index: 1 - runtime_type: "ifrt-service" - devices { - id: 0 - local_hardware_id: 1234 - device_kind: "mock" - default_memory_id: 0 - memory_ids: [ 0 ] - attributes { - key: "name" - value { string_value: "device0" } + if (Version().protocol_version() <= 3) { + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + platform_name: "ifrt-service" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + deprecated_attributes { + key: "name" + value { string_value: "device0" } + } } - } - devices { - id: 1 - local_hardware_id: 1234 - device_kind: "mock" - default_memory_id: 1 - memory_ids: [ 1 ] - attributes { - key: "name" - value { string_value: "device1" } + devices { + id: 1 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + deprecated_attributes { + key: "name" + value { string_value: "device1" } + } } - } - addressable_device_ids: 1 - memories { - id: 0 - memory_space_kind: "mock" - kind_id: 0 - device_ids: [ 0 ] - } - memories { - id: 1 - memory_space_kind: "mock" - kind_id: 1 - device_ids: [ 1 ] - } - )pb", - &response)); + addressable_device_ids: 1 + memories { + id: 0 + memory_space_kind: "mock" + kind_id: 0 + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + kind_id: 1 + device_ids: [ 1 ] + } + )pb", + &response)); + } else { + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + platform_name: "ifrt-service" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + devices { + id: 1 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + addressable_device_ids: 1 + memories { + id: 0 + memory_space_kind: "mock" + kind_id: 0 + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + kind_id: 1 + device_ids: [ 1 ] + } + )pb", + &response)); + } TF_ASSERT_OK_AND_ASSIGN(client_, Client::Create(rpc_helper_, response)); } @@ -142,9 +193,8 @@ TEST_P(ClientTest, Init) { client_->LookupDevice(DeviceId(0))); EXPECT_EQ(device0->Id(), DeviceId(0)); EXPECT_EQ(device0->Kind(), "mock"); - EXPECT_THAT(device0->Attributes(), - ElementsAre(Pair( - "name", xla::PjRtDeviceAttribute(std::string("device0"))))); + EXPECT_THAT(device0->Attributes().map(), + ElementsAre(Pair("name", AttributeMap::StringValue("device0")))); ASSERT_THAT(device0->Memories(), SizeIs(1)); auto* const memory0 = device0->Memories()[0]; @@ -157,9 +207,8 @@ TEST_P(ClientTest, Init) { client_->LookupDevice(DeviceId(1))); EXPECT_EQ(device1->Id(), 1); EXPECT_EQ(device1->Kind(), "mock"); - EXPECT_THAT(device1->Attributes(), - ElementsAre(Pair( - "name", xla::PjRtDeviceAttribute(std::string("device1"))))); + EXPECT_THAT(device1->Attributes().map(), + ElementsAre(Pair("name", AttributeMap::StringValue("device1")))); ASSERT_THAT(device1->Memories(), SizeIs(1)); auto* const memory1 = device1->Memories()[0]; diff --git a/xla/python/ifrt_proxy/client/device.cc b/xla/python/ifrt_proxy/client/device.cc index a368f395fe37d4..96b2890a5aac48 100644 --- a/xla/python/ifrt_proxy/client/device.cc +++ b/xla/python/ifrt_proxy/client/device.cc @@ -14,16 +14,28 @@ #include "xla/python/ifrt_proxy/client/device.h" +#include + #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" namespace xla { namespace ifrt { namespace proxy { +Device::Device(DeviceDescription description, int local_device_id, + int local_hardware_id, bool is_addressable) + : description_(std::move(description)), + attributes_(FromPjRtDeviceAttributeMap(description_.Attributes())), + local_device_id_(local_device_id), + local_hardware_id_(local_hardware_id), + is_addressable_(is_addressable) {} + ifrt::Client* Device::client() const { return client_; } DeviceId Device::Id() const { return DeviceId(description_.id()); } @@ -48,10 +60,7 @@ absl::StatusOr Device::DefaultMemory() const { int Device::ProcessIndex() const { return description_.process_index(); } -const absl::flat_hash_map& -Device::Attributes() const { - return description_.Attributes(); -} +const AttributeMap& Device::Attributes() const { return attributes_; } char Device::ID = 0; // NOLINT diff --git a/xla/python/ifrt_proxy/client/device.h b/xla/python/ifrt_proxy/client/device.h index cb6888d5699841..7e0c684d8b0e5a 100644 --- a/xla/python/ifrt_proxy/client/device.h +++ b/xla/python/ifrt_proxy/client/device.h @@ -27,6 +27,7 @@ #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" @@ -76,11 +77,7 @@ class DeviceDescription final : public xla::PjRtDeviceDescription { class Device final : public llvm::RTTIExtends { public: Device(DeviceDescription description, int local_device_id, - int local_hardware_id, bool is_addressable) - : description_(std::move(description)), - local_device_id_(local_device_id), - local_hardware_id_(local_hardware_id), - is_addressable_(is_addressable) {} + int local_hardware_id, bool is_addressable); ifrt::Client* client() const override; bool IsAddressable() const override; @@ -91,8 +88,7 @@ class Device final : public llvm::RTTIExtends { absl::string_view DebugString() const override; int ProcessIndex() const override; - const absl::flat_hash_map& Attributes() - const override; + const AttributeMap& Attributes() const override; absl::Span Memories() const override; absl::StatusOr DefaultMemory() const override; @@ -104,6 +100,9 @@ class Device final : public llvm::RTTIExtends { ifrt::Client* client_; const DeviceDescription description_; + + const AttributeMap attributes_; + const int local_device_id_; const int local_hardware_id_; const bool is_addressable_; diff --git a/xla/python/ifrt_proxy/client/version.h b/xla/python/ifrt_proxy/client/version.h index f88678645e8b77..f713b2301ba396 100644 --- a/xla/python/ifrt_proxy/client/version.h +++ b/xla/python/ifrt_proxy/client/version.h @@ -24,7 +24,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kClientMinVersion = 1; -inline constexpr int kClientMaxVersion = 3; +inline constexpr int kClientMaxVersion = 4; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) } // namespace proxy diff --git a/xla/python/ifrt_proxy/common/BUILD b/xla/python/ifrt_proxy/common/BUILD index ddc5e2a5670064..621b043fd72554 100644 --- a/xla/python/ifrt_proxy/common/BUILD +++ b/xla/python/ifrt_proxy/common/BUILD @@ -70,6 +70,7 @@ tf_proto_library( # copybara:uncomment "@com_google_protobuf//:any", "//xla:xla_data_proto", "//xla/pjrt:execute_options_proto", + "//xla/python/ifrt:attribute_map_proto", "//xla/python/ifrt:dtype_proto", "//xla/python/ifrt:remap_plan_proto", "//xla/python/ifrt:serdes_proto", diff --git a/xla/python/ifrt_proxy/common/VERSION.md b/xla/python/ifrt_proxy/common/VERSION.md index 854d95f55134d5..4166a27daf9ca7 100644 --- a/xla/python/ifrt_proxy/common/VERSION.md +++ b/xla/python/ifrt_proxy/common/VERSION.md @@ -17,3 +17,10 @@ * Added date: 2024-06-17. * Changes: * Added native support for `Client::CopyArrays()`. + +## Version 4 + +* Added date: 2024-06-18. +* Changes: + * Changed the serialization of client and device attributes to use `xla.ifrt.AttributeMapProto` instead of `map`. + diff --git a/xla/python/ifrt_proxy/common/ifrt_service.proto b/xla/python/ifrt_proxy/common/ifrt_service.proto index 945b66a3538045..e9f183b7467f64 100644 --- a/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -18,6 +18,7 @@ package xla.ifrt.proxy; import "google/protobuf/any.proto"; import "xla/pjrt/execute_options.proto"; +import "xla/python/ifrt/attribute_map.proto"; import "xla/python/ifrt/dtype.proto"; import "xla/python/ifrt/remap_plan.proto"; import "xla/python/ifrt/serdes.proto"; @@ -218,7 +219,9 @@ message InitResponse { repeated int32 memory_ids = 8; string debug_string = 4; string to_string = 5; - map attributes = 6; + map deprecated_attributes = 6 + [deprecated = true]; // Deprecated since Version 4. + AttributeMapProto attributes = 10; // New in Version 4. } repeated Device devices = 6; // == ifrt::Client::devices() diff --git a/xla/python/ifrt_proxy/common/types.proto b/xla/python/ifrt_proxy/common/types.proto index ca3829891d7629..49c3c7e1304570 100644 --- a/xla/python/ifrt_proxy/common/types.proto +++ b/xla/python/ifrt_proxy/common/types.proto @@ -18,6 +18,9 @@ package xla.ifrt.proto; // Mirrors `xla::PjRtValueType`, which is used in IFRT to model // polymorphic-typed values, e.g., `xla::ifrt::Executable::CostAnalysisValue`. +// +// Deprecated since Version 4. Use `xla::ifrt::AttributeMapProto::Value` +// instead. message Variant { message Int64List { repeated sfixed64 values = 1; diff --git a/xla/python/ifrt_proxy/server/BUILD b/xla/python/ifrt_proxy/server/BUILD index 7a75849446f5a8..d74970ed909e41 100644 --- a/xla/python/ifrt_proxy/server/BUILD +++ b/xla/python/ifrt_proxy/server/BUILD @@ -171,10 +171,9 @@ ifrt_proxy_cc_test( "//xla:test", "//xla:xla_data_proto_cc", "//xla/pjrt:host_callback", - "//xla/pjrt:pjrt_common", - "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", "//xla/python/ifrt:mock", "//xla/python/ifrt:serdes", "//xla/python/ifrt:sharding_serdes", diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.cc b/xla/python/ifrt_proxy/server/ifrt_backend.cc index 85521905214040..1b257b60263d24 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -274,9 +275,16 @@ absl::StatusOr IfrtBackend::HandleInit( } d->set_debug_string(AsProtoStringData(device->DebugString())); d->set_to_string(AsProtoStringData(device->ToString())); - for (const auto& [name, attr] : device->Attributes()) { - TF_ASSIGN_OR_RETURN((*d->mutable_attributes())[name], - ToVariantProto(attr)); + if (version_.protocol_version() <= 3) { + for (const auto& [name, attr] : device->Attributes().map()) { + TF_ASSIGN_OR_RETURN( + (*d->mutable_deprecated_attributes())[name], + std::visit( + [&](const auto& attr) { return ToVariantProto(attr.value); }, + attr)); + } + } else { + *d->mutable_attributes() = device->Attributes().ToProto(); } } for (auto* addressable_device : client_->addressable_devices()) { diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index f54ea4effd7371..91b428bd2aa26d 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -41,10 +41,9 @@ #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/host_callback.h" -#include "xla/pjrt/pjrt_common.h" -#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" @@ -397,12 +396,13 @@ TEST_P(IfrtBackendHandlerTest, Init) { device_memories.push_back({&mock_memories[i]}); } - using AttributeMap = - absl::flat_hash_map; - std::vector device_attributes(mock_devices_.size()); + std::vector device_attributes; + device_attributes.reserve(mock_devices_.size()); for (int i = 0; i < mock_devices_.size(); ++i) { - device_attributes[i].insert({"name", absl::StrCat("device", i)}); + AttributeMap::Map map; + map.insert({"name", AttributeMap::StringValue(absl::StrCat("device", i))}); + device_attributes.push_back(AttributeMap(std::move(map))); MockDevice& mock_device = *mock_devices_[i]; // TODO(b/314368788): Clean up PJRT device ID APIs. @@ -418,48 +418,97 @@ TEST_P(IfrtBackendHandlerTest, Init) { auto request = NewIfrtRequest(NewOpId()); request->mutable_init_request(); - EXPECT_THAT(CallBackend(std::move(request)), - IsOkAndHolds(Pointee( - Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( - init_response { - session_id: 12345 - platform_name: "ifrt_backend" - platform_version: "n/a" - platform_id: 42 - process_index: 1 - runtime_type: "ifrt-service" - devices { - id: 0 - device_kind: "mock" - default_memory_id: 0 - memory_ids: [ 0 ] - attributes { - key: "name" - value { string_value: "device0" } + if (Version().protocol_version() <= 3) { + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee( + Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( + init_response { + session_id: 12345 + platform_name: "ifrt_backend" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + deprecated_attributes { + key: "name" + value { string_value: "device0" } + } } - } - devices { - id: 1 - device_kind: "mock" - default_memory_id: 1 - memory_ids: [ 1 ] - attributes { - key: "name" - value { string_value: "device1" } + devices { + id: 1 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + deprecated_attributes { + key: "name" + value { string_value: "device1" } + } + } + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] } } - memories { - id: 0 - memory_space_kind: "mock" - device_ids: [ 0 ] - } - memories { - id: 1 - memory_space_kind: "mock" - device_ids: [ 1 ] + )pb")))))); + } else { + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee( + Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( + init_response { + session_id: 12345 + platform_name: "ifrt_backend" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + devices { + id: 1 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } } - } - )pb")))))); + )pb")))))); + } } #endif diff --git a/xla/python/ifrt_proxy/server/version.h b/xla/python/ifrt_proxy/server/version.h index 4efd3648ab94e3..686fe78993bfd2 100644 --- a/xla/python/ifrt_proxy/server/version.h +++ b/xla/python/ifrt_proxy/server/version.h @@ -26,7 +26,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kServerMinVersion = 1; -inline constexpr int kServerMaxVersion = 3; +inline constexpr int kServerMaxVersion = 4; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) // Returns a version that both the client and the server support, or an error if diff --git a/xla/python/pjrt_ifrt/BUILD b/xla/python/pjrt_ifrt/BUILD index 50e2e4ba6382cb..ce9502d19f0dd5 100644 --- a/xla/python/pjrt_ifrt/BUILD +++ b/xla/python/pjrt_ifrt/BUILD @@ -213,6 +213,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":basic_string_array", + ":pjrt_attribute_map_util", ":xla_ifrt", "//xla:literal", "//xla:shape_util", @@ -233,6 +234,7 @@ cc_library( "//xla/pjrt/distributed:protocol_proto_cc", "//xla/pjrt/distributed:topology_util", "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", "//xla/python/ifrt/hlo:hlo_program", "//xla/service:hlo_proto_cc", "//xla/translate/mhlo_to_hlo:type_to_shape", diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index f4a4e4c79687a4..8156305c0b5de4 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -51,6 +51,7 @@ limitations under the License. #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" @@ -64,6 +65,7 @@ limitations under the License. #include "xla/python/ifrt/value.h" #include "xla/python/pjrt_ifrt/basic_string_array.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/pjrt_remap.h" @@ -89,6 +91,26 @@ absl::AnyInvocable FromStdFunction(std::function&& f) { return f ? std::move(f) : absl::AnyInvocable(); } +// Returns an `AttributeMap` with the attributes of the given `PjRtClient`. +AttributeMap MakeAttributeMap(xla::PjRtClient* pjrt_client) { + absl::flat_hash_map attributes; + attributes.insert({"supports_executable_serialization", true}); + if (std::optional plugin_attributes = + pjrt_client->plugin_attributes(); + plugin_attributes.has_value()) { + attributes.insert( + {"pjrt_c_api_major_version", + PjRtValueType(plugin_attributes->pjrt_c_api_major_version)}); + attributes.insert( + {"pjrt_c_api_minor_version", + PjRtValueType(plugin_attributes->pjrt_c_api_minor_version)}); + for (const auto& [key, value] : plugin_attributes->attributes) { + attributes.insert({key, value}); + } + } + return FromPjRtDeviceAttributeMap(std::move(attributes)); +} + void SerializePjRtDeviceAttributes( const absl::flat_hash_map& attributes, DeviceProto& device_proto) { @@ -455,7 +477,9 @@ std::unique_ptr PjRtClient::Create( } PjRtClient::PjRtClient(std::shared_ptr pjrt_client) - : pjrt_client_(std::move(pjrt_client)), default_compiler_(this) {} + : pjrt_client_(std::move(pjrt_client)), + default_compiler_(this), + attributes_(MakeAttributeMap(pjrt_client_.get())) {} PjRtClient::~PjRtClient() = default; @@ -498,27 +522,7 @@ absl::StatusOr PjRtClient::LookupAddressableDevice( return LookupPjRtDevice(pjrt_device); } -absl::flat_hash_map -PjRtClient::attributes() const { - absl::flat_hash_map attributes; - attributes.insert({"supports_executable_serialization", true}); - - if (std::optional plugin_attributes = - pjrt_client_->plugin_attributes(); - plugin_attributes.has_value()) { - attributes.insert( - {"pjrt_c_api_major_version", - ClientAttribute(plugin_attributes->pjrt_c_api_major_version)}); - attributes.insert( - {"pjrt_c_api_minor_version", - ClientAttribute(plugin_attributes->pjrt_c_api_minor_version)}); - for (const auto& [key, value] : plugin_attributes->attributes) { - attributes.insert({key, value}); - } - } - - return attributes; -} +const AttributeMap& PjRtClient::Attributes() const { return attributes_; } absl::StatusOr> PjRtClient::CreatePjRtArray(std::shared_ptr pjrt_buffer) { diff --git a/xla/python/pjrt_ifrt/pjrt_client.h b/xla/python/pjrt_ifrt/pjrt_client.h index b7841f7a043606..51915711f3aeae 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.h +++ b/xla/python/pjrt_ifrt/pjrt_client.h @@ -37,6 +37,7 @@ limitations under the License. #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -184,7 +185,7 @@ class PjRtClient final return pjrt_client_->platform_id(); } - absl::flat_hash_map attributes() const override; + const AttributeMap& Attributes() const override; int device_count() const override { DCHECK(this); @@ -247,6 +248,8 @@ class PjRtClient final std::shared_ptr pjrt_client_; PjRtCompiler default_compiler_; + AttributeMap attributes_; + std::vector> owned_devices_; std::vector> owned_memories_; diff --git a/xla/python/pjrt_ifrt/pjrt_device.cc b/xla/python/pjrt_ifrt/pjrt_device.cc index 48a7225e3d7bd3..633fb0674743e1 100644 --- a/xla/python/pjrt_ifrt/pjrt_device.cc +++ b/xla/python/pjrt_ifrt/pjrt_device.cc @@ -24,8 +24,10 @@ limitations under the License. #include "absl/types/span.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" @@ -43,15 +45,17 @@ PjRtDevice::PjRtDevice( xla::PjRtDevice* pjrt_device) : client_(client), id_(id), + attributes_(FromPjRtDeviceAttributeMap(std::move(attributes))), kind_(std::move(kind)), to_string_(std::move(to_string)), debug_string_(std::move(debug_string)), process_index_(process_index), - attributes_(std::move(attributes)), pjrt_device_(pjrt_device) {} DeviceId PjRtDevice::Id() const { return id_; } +const AttributeMap& PjRtDevice::Attributes() const { return attributes_; } + absl::string_view PjRtDevice::Kind() const { return kind_; } absl::string_view PjRtDevice::ToString() const { return to_string_; } @@ -68,10 +72,5 @@ absl::Span PjRtDevice::Memories() const { return memories_; } int PjRtDevice::ProcessIndex() const { return process_index_; } -const absl::flat_hash_map& -PjRtDevice::Attributes() const { - return attributes_; -} - } // namespace ifrt } // namespace xla diff --git a/xla/python/pjrt_ifrt/pjrt_device.h b/xla/python/pjrt_ifrt/pjrt_device.h index dd2790615785e9..596db196304df4 100644 --- a/xla/python/pjrt_ifrt/pjrt_device.h +++ b/xla/python/pjrt_ifrt/pjrt_device.h @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" @@ -55,6 +56,7 @@ class PjRtDevice final PjRtClient* client() const override { return client_; } DeviceId Id() const final; + const AttributeMap& Attributes() const final; absl::string_view Kind() const final; absl::string_view ToString() const final; absl::string_view DebugString() const final; @@ -62,8 +64,6 @@ class PjRtDevice final absl::StatusOr DefaultMemory() const final; absl::Span Memories() const final; int ProcessIndex() const final; - const absl::flat_hash_map& Attributes() - const final; static char ID; // NOLINT @@ -73,13 +73,13 @@ class PjRtDevice final PjRtClient* client_; DeviceId id_; + AttributeMap attributes_; std::string kind_; std::string to_string_; std::string debug_string_; absl::StatusOr default_memory_; std::vector memories_; int process_index_; - absl::flat_hash_map attributes_; xla::PjRtDevice* pjrt_device_; }; diff --git a/xla/python/pjrt_ifrt/pjrt_topology.cc b/xla/python/pjrt_ifrt/pjrt_topology.cc index 6d76b16c02233c..0bde97a07eb3a3 100644 --- a/xla/python/pjrt_ifrt/pjrt_topology.cc +++ b/xla/python/pjrt_ifrt/pjrt_topology.cc @@ -21,13 +21,14 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/layout.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" namespace xla::ifrt { @@ -35,7 +36,8 @@ char PjRtTopology::ID = 0; PjRtTopology::PjRtTopology( std::shared_ptr description) - : description_(std::move(description)) {} + : description_(std::move(description)), + attributes_(FromPjRtDeviceAttributeMap(description_->Attributes())) {} absl::string_view PjRtTopology::platform_name() const { return description_->platform_name(); @@ -63,9 +65,6 @@ absl::StatusOr PjRtTopology::Serialize() const { return description_->Serialize(); } -const absl::flat_hash_map& -PjRtTopology::Attributes() const { - return description_->Attributes(); -} +const AttributeMap& PjRtTopology::Attributes() const { return attributes_; } } // namespace xla::ifrt diff --git a/xla/python/pjrt_ifrt/pjrt_topology.h b/xla/python/pjrt_ifrt/pjrt_topology.h index c854cb24862fc7..82fc59c8005c01 100644 --- a/xla/python/pjrt_ifrt/pjrt_topology.h +++ b/xla/python/pjrt_ifrt/pjrt_topology.h @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -29,6 +28,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/topology.h" namespace xla::ifrt { @@ -56,13 +56,13 @@ class PjRtTopology final : public llvm::RTTIExtends { absl::StatusOr Serialize() const override; - const absl::flat_hash_map& Attributes() - const override; + const AttributeMap& Attributes() const override; static char ID; // NOLINT private: std::shared_ptr description_; + AttributeMap attributes_; }; } // namespace xla::ifrt diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 10474e6af4a8c6..7a0cffb86c4ed1 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -113,7 +113,7 @@ namespace nb = nanobind; PyClient::PyClient(std::shared_ptr ifrt_client) : ifrt_client_(std::move(ifrt_client)), - client_attributes_(ifrt_client_->attributes()) { + client_attributes_(ifrt_client_->Attributes()) { CHECK(ifrt_client_); } @@ -763,10 +763,10 @@ PyType_Slot PyClient::slots_[] = { nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) .def("__getattr__", [](PyClient& client, std::string_view name) -> nb::object { - const auto& attrs = client.attributes(); + const auto& attrs = client.Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { - return std::visit([](auto&& v) { return nb::cast(v); }, + return std::visit([](auto&& v) { return nb::cast(v.value); }, it->second); } throw nb::attribute_error( diff --git a/xla/python/py_client.h b/xla/python/py_client.h index c95e03bff1097e..79dac7a4f0eeb1 100644 --- a/xla/python/py_client.h +++ b/xla/python/py_client.h @@ -36,9 +36,11 @@ limitations under the License. #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/program.h" #include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/shape.h" @@ -112,8 +114,7 @@ class PyClient { // Returns implementation-specific attributes about this client, e.g. the PJRT // C API version if applicable. - const absl::flat_hash_map& - attributes() const { + const xla::ifrt::AttributeMap& Attributes() const { return client_attributes_; } @@ -236,8 +237,7 @@ class PyClient { static PyType_Slot slots_[]; std::shared_ptr ifrt_client_; - absl::flat_hash_map - client_attributes_; + xla::ifrt::AttributeMap client_attributes_; // Pointers to intrusive doubly-linked lists of arrays and executables, used // to iterate over all known objects when heap profiling. The list structure // is protected by the GIL. diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index 91604fe53a66b7..49b7000def469f 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -45,6 +44,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -61,6 +61,7 @@ limitations under the License. #include "xla/python/ifrt/value.h" #include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" @@ -82,7 +83,9 @@ class CompileOnlyDevice : public llvm::RTTIExtends { public: explicit CompileOnlyDevice(const PjRtDeviceDescription* description) - : description_(std::move(description)) {} + : description_(std::move(description)), + attributes_( + ifrt::FromPjRtDeviceAttributeMap(description_->Attributes())) {} const PjRtDeviceDescription& description() const { return *description_; } @@ -111,13 +114,11 @@ class CompileOnlyDevice return Unimplemented("DefaultMemory is not supported"); } - const absl::flat_hash_map& Attributes() - const { - return description_->Attributes(); - } + const ifrt::AttributeMap& Attributes() const override { return attributes_; } private: const PjRtDeviceDescription* description_; + ifrt::AttributeMap attributes_; }; class InvalidIfrtCompiler final @@ -151,7 +152,8 @@ class CompileOnlyIfRtClient final public: explicit CompileOnlyIfRtClient(std::shared_ptr topology) : topology_(std::move(topology)), - descriptions_(topology_->DeviceDescriptions()) { + descriptions_(topology_->DeviceDescriptions()), + attributes_(ifrt::AttributeMap::Map()) { for (auto& description : descriptions_) { owned_devices_.push_back( std::make_unique(description.get())); @@ -218,10 +220,7 @@ class CompileOnlyIfRtClient final ifrt::PlatformId platform_id() const override { return topology_->platform_id(); } - absl::flat_hash_map attributes() - const override { - return {}; - } + const ifrt::AttributeMap& Attributes() const override { return attributes_; } int device_count() const override { return devices().size(); } int addressable_device_count() const override { return 0; } @@ -271,6 +270,7 @@ class CompileOnlyIfRtClient final InvalidIfrtCompiler default_compiler_; std::shared_ptr topology_; std::vector> descriptions_; + ifrt::AttributeMap attributes_; std::vector> owned_devices_; std::vector devices_; }; diff --git a/xla/python/py_device.cc b/xla/python/py_device.cc index 3ba6d2a6e76208..371a6826e28480 100644 --- a/xla/python/py_device.cc +++ b/xla/python/py_device.cc @@ -322,11 +322,11 @@ PyType_Slot PyDevice::slots_[] = { try { auto device = nb::cast(nb::handle(self)); auto name = nb::cast(nb::handle(key)); - const auto& attrs = device->device_->Attributes(); + const auto& attrs = device->device_->Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { - auto result = - std::visit([](auto&& v) { return nb::cast(v); }, it->second); + auto result = std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); return result.release().ptr(); } PyErr_SetNone(PyExc_AttributeError); diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 8e8842b31c2521..2eecf7d45fed42 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -840,10 +840,10 @@ NB_MODULE(xla_extension, m_nb) { }) .def("__getattr__", [](ifrt::Topology& topology, std::string_view name) -> nb::object { - const auto& attrs = topology.Attributes(); + const auto& attrs = topology.Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { - return std::visit([](auto&& v) { return nb::cast(v); }, + return std::visit([](auto&& v) { return nb::cast(v.value); }, it->second); } throw nb::attribute_error(