Skip to content

Commit

Permalink
[IFRT] Add Client::GetAllDevices()
Browse files Browse the repository at this point in the history
This defines `Client::GetAllDevices()`. It is similar to `Client::devices()`,
but it enumerates all devices available on the client, regardless of the
type/kind of devices. This multi-device behavior was implemented on certain
IFRT implementations, and PjRt-IFRT would return the same devices for now
because it has only one PjRt client. In the future, however, PjRt-IFRT would
support multiple device types and it `GetAllDevices()` will return more devices
while `devices()` may keep returning the devices of the primary device
type/kind.

`Client::GetAllDevices()` is essentially a transitional API in that its role
will be absorbed back to `Client::devices()` by making `Client::devices()` also
return all devices. However, this can only happen after the user code has been
updated to apply device filtering so that the legacy behavior around using
`Client::devices()` remains unchanged. By having a separate
`Client::GetAllDevices()` while the transition happens, we can incrementally
migrate the user code to apply the device filtering.

PiperOrigin-RevId: 679748878
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Oct 4, 2024
1 parent 6eb346b commit 5db9b92
Show file tree
Hide file tree
Showing 15 changed files with 208 additions and 27 deletions.
1 change: 1 addition & 0 deletions xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ cc_library(
deps = [
":ifrt",
":test_util",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
alwayslink = True,
Expand Down
5 changes: 5 additions & 0 deletions xla/python/ifrt/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
virtual absl::Span<Device* const> addressable_devices() const = 0;
virtual int process_index() const = 0;

// Returns all devices. The result includes primary devices that are included
// in `devices()` as well as any other devices that are associated with
// the primary devices.
virtual absl::Span<xla::ifrt::Device* const> GetAllDevices() const = 0;

// TODO(hyeontaek): Consider removing this API. This API is potentially not
// being used by JAX or will be replaced with explicit device assignment.
virtual absl::StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
Expand Down
15 changes: 14 additions & 1 deletion xla/python/ifrt/client_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
Expand Down Expand Up @@ -54,6 +55,18 @@ TEST(ClientImplTest, Devices) {
EXPECT_GE(client->process_index(), 0);
}

TEST(ClientImplTest, GetAllDevices) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());

EXPECT_GE(client->GetAllDevices().size(), client->device_count());

for (Device* device : client->GetAllDevices()) {
TF_ASSERT_OK_AND_ASSIGN(auto* looked_up_device,
client->LookupDevice(device->Id()));
EXPECT_EQ(device, looked_up_device);
}
}

TEST(ClientImplTest, DefaultCompiler) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
EXPECT_THAT(client->GetDefaultCompiler(), NotNull());
Expand Down
3 changes: 3 additions & 0 deletions xla/python/ifrt/mock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
ON_CALL(*this, process_index).WillByDefault([this]() {
return delegated_->process_index();
});
ON_CALL(*this, GetAllDevices).WillByDefault([this]() {
return delegated_->GetAllDevices();
});
ON_CALL(*this, GetDefaultDeviceAssignment)
.WillByDefault([this](int num_replicas, int num_partitions) {
return delegated_->GetDefaultDeviceAssignment(num_replicas,
Expand Down
2 changes: 2 additions & 0 deletions xla/python/ifrt/mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
MOCK_METHOD(absl::Span<Device* const>, addressable_devices, (),
(const, final));
MOCK_METHOD(int, process_index, (), (const, final));
MOCK_METHOD(absl::Span<xla::ifrt::Device* const>, GetAllDevices, (),
(const, final));
MOCK_METHOD(absl::StatusOr<DeviceAssignment>, GetDefaultDeviceAssignment,
(int num_replicas, int num_partitions), (const, final));
MOCK_METHOD(absl::StatusOr<Device*>, LookupDevice, (DeviceId device_id),
Expand Down
40 changes: 33 additions & 7 deletions xla/python/ifrt_proxy/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
absl::flat_hash_set<int> addressable_device_ids(
init_response.addressable_device_ids().begin(),
init_response.addressable_device_ids().end());
absl::flat_hash_set<int> primary_device_ids;
if (rpc_helper->version().protocol_version() < 7) {
// Legacy implementation for servers do not support Client::GetAllDevices()
// and thus do not provide device_ids(). Assume that it contains all device
// ids from devices().
primary_device_ids.reserve(init_response.all_devices().size());
for (const auto& d : init_response.all_devices()) {
primary_device_ids.insert(d.id());
}
} else {
primary_device_ids.reserve(init_response.primary_device_ids().size());
primary_device_ids.insert(init_response.primary_device_ids().begin(),
init_response.primary_device_ids().end());
}

absl::flat_hash_map<int, std::unique_ptr<Memory>> memories;
for (const auto& m : init_response.memories()) {
Expand All @@ -77,10 +91,11 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
}

absl::flat_hash_map<int, std::unique_ptr<Device>> devices;
std::vector<xla::ifrt::Device*> device_ptrs;
std::vector<xla::ifrt::Device*> primary_device_ptrs;
std::vector<xla::ifrt::Device*> addressable_device_ptrs;
std::vector<xla::ifrt::Device*> all_device_ptrs;

for (const auto& d : init_response.devices()) {
for (const auto& d : init_response.all_devices()) {
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>
pjrt_device_attributes;
if (rpc_helper->version().protocol_version() <= 3) {
Expand All @@ -99,14 +114,18 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
d.device_kind(), d.debug_string(), d.to_string(),
std::move(pjrt_device_attributes));
bool is_addressable = addressable_device_ids.contains(d.id());
bool is_primary = primary_device_ids.contains(d.id());

auto device =
std::make_unique<Device>(std::move(desc), d.local_device_id(),
d.local_hardware_id(), is_addressable);
device_ptrs.push_back(device.get());
all_device_ptrs.push_back(device.get());
if (is_addressable) {
addressable_device_ptrs.push_back(device.get());
}
if (is_primary) {
primary_device_ptrs.push_back(device.get());
}

if (d.has_default_memory_id()) {
const auto it = memories.find(d.default_memory_id());
Expand Down Expand Up @@ -150,9 +169,10 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
std::move(rpc_helper), init_response.session_id(),
init_response.platform_name(), init_response.platform_version(),
init_response.platform_id(), init_response.process_index(), runtime_type,
std::move(devices), device_ptrs, std::move(addressable_device_ptrs),
std::move(devices), std::move(primary_device_ptrs),
std::move(addressable_device_ptrs), all_device_ptrs,
std::move(memories)));
for (ifrt::Device* device : device_ptrs) {
for (ifrt::Device* device : all_device_ptrs) {
tensorflow::down_cast<Device*>(device)->client_ = client.get();
}
return client;
Expand All @@ -163,8 +183,9 @@ Client::Client(std::shared_ptr<RpcHelper> rpc_helper, uint64_t session_id,
uint64_t platform_id, uint64_t process_index,
std::string runtime_type,
absl::flat_hash_map<int, std::unique_ptr<Device>> devices,
std::vector<xla::ifrt::Device*> device_ptrs,
std::vector<xla::ifrt::Device*> primary_device_ptrs,
std::vector<xla::ifrt::Device*> addressable_device_ptrs,
std::vector<xla::ifrt::Device*> all_device_ptrs,
absl::flat_hash_map<int, std::unique_ptr<Memory>> memories)
: rpc_helper_(rpc_helper),
platform_name_(std::move(platform_name)),
Expand All @@ -175,8 +196,9 @@ Client::Client(std::shared_ptr<RpcHelper> rpc_helper, uint64_t session_id,
// TODO(b/309059940): Forward the backend attributes to the client.
attributes_(AttributeMap::Map()),
devices_(std::move(devices)),
device_ptrs_(device_ptrs),
primary_device_ptrs_(primary_device_ptrs),
addressable_device_ptrs_(std::move(addressable_device_ptrs)),
all_device_ptrs_(all_device_ptrs),
memories_(std::move(memories)),
default_compiler_(this, rpc_helper) {}

Expand Down Expand Up @@ -302,6 +324,10 @@ xla::ifrt::Future<> Client::GetReadyFuture(
return JoinFutures(futures);
}

absl::Span<xla::ifrt::Device* const> Client::GetAllDevices() const {
return all_device_ptrs_;
}

absl::StatusOr<DeviceAssignment> Client::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
auto req = std::make_unique<GetDefaultDeviceAssignmentRequest>();
Expand Down
9 changes: 6 additions & 3 deletions xla/python/ifrt_proxy/client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
return addressable_devices().size();
}
absl::Span<xla::ifrt::Device* const> devices() const override {
return device_ptrs_;
return primary_device_ptrs_;
}
absl::Span<xla::ifrt::Device* const> addressable_devices() const override {
return addressable_device_ptrs_;
}
int process_index() const override { return process_index_; }
absl::Span<xla::ifrt::Device* const> GetAllDevices() const override;
absl::StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
absl::StatusOr<xla::ifrt::Device*> LookupDevice(
Expand Down Expand Up @@ -148,8 +149,9 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
std::string platform_name, std::string platform_version,
uint64_t platform_id, uint64_t process_index, std::string runtime_type,
absl::flat_hash_map<int, std::unique_ptr<Device>> devices,
std::vector<xla::ifrt::Device*> device_ptrs,
std::vector<xla::ifrt::Device*> primary_device_ptrs,
std::vector<xla::ifrt::Device*> addressable_device_ptrs,
std::vector<xla::ifrt::Device*> all_device_ptrs,
absl::flat_hash_map<int, std::unique_ptr<Memory>> memories);

// rpc_helper_ will be referenced by various IFRT objects whose lifetime is
Expand All @@ -166,8 +168,9 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
const AttributeMap attributes_;

const absl::flat_hash_map<int, std::unique_ptr<Device>> devices_;
const std::vector<xla::ifrt::Device*> device_ptrs_;
const std::vector<xla::ifrt::Device*> primary_device_ptrs_;
const std::vector<xla::ifrt::Device*> addressable_device_ptrs_;
const std::vector<xla::ifrt::Device*> all_device_ptrs_;

const absl::flat_hash_map<int, std::unique_ptr<Memory>> memories_;

Expand Down
58 changes: 54 additions & 4 deletions xla/python/ifrt_proxy/client/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
platform_id: 42
process_index: 1
runtime_type: "ifrt-service"
devices {
all_devices {
id: 0
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -94,7 +94,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
value { string_value: "device0" }
}
}
devices {
all_devices {
id: 1
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -120,6 +120,55 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
}
)pb",
&response));
} else if (Version().protocol_version() < 7) {
ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString(
R"pb(
platform_name: "ifrt-service"
platform_version: "n/a"
platform_id: 42
process_index: 1
runtime_type: "ifrt-service"
all_devices {
id: 0
local_hardware_id: 1234
device_kind: "mock"
default_memory_id: 0
memory_ids: [ 0 ]
attributes {
attributes {
key: "name"
value { string_value: "device0" }
}
}
}
all_devices {
id: 1
local_hardware_id: 1234
device_kind: "mock"
default_memory_id: 1
memory_ids: [ 1 ]
attributes {
attributes {
key: "name"
value { string_value: "device1" }
}
}
}
addressable_device_ids: 1
memories {
id: 0
memory_space_kind: "mock"
kind_id: 0
device_ids: [ 0 ]
}
memories {
id: 1
memory_space_kind: "mock"
kind_id: 1
device_ids: [ 1 ]
}
)pb",
&response));
} else {
ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString(
R"pb(
Expand All @@ -128,7 +177,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
platform_id: 42
process_index: 1
runtime_type: "ifrt-service"
devices {
all_devices {
id: 0
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -141,7 +190,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
}
}
}
devices {
all_devices {
id: 1
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -154,6 +203,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
}
}
}
primary_device_ids: [ 0, 1 ]
addressable_device_ids: 1
memories {
id: 0
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt_proxy/client/version.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace proxy {
// LINT.IfChange
// TODO(b/296144873): Document the version upgrade policy.
inline constexpr int kClientMinVersion = 3;
inline constexpr int kClientMaxVersion = 6;
inline constexpr int kClientMaxVersion = 7;
// LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md)

} // namespace proxy
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt_proxy/common/VERSION.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,9 @@
* Added date: 2024-09-30.
* Changes:
* Added `ExecuteOptions::fill_status`.

## Version 7

* Added date: 2024-10-01.
* Changes:
* Added support for `Client::GetAllDevices()`.
7 changes: 5 additions & 2 deletions xla/python/ifrt_proxy/common/ifrt_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,12 @@ message InitResponse {
AttributeMapProto attributes = 10; // New in Version 4.
}

repeated Device devices = 6; // == ifrt::Client::devices()
repeated Device all_devices = 6; // == ifrt::Client::GetAllDevices()
repeated int32 primary_device_ids =
10; // == [device.id for device in ifrt::Client::devices()]
repeated int32 addressable_device_ids =
7; // == ifrt::Client::addressable_devices()
7; // == [device.id for device in ifrt::Client::GetAllDevices() if
// device.IsAddressable()]

message Memory {
int32 id = 1;
Expand Down
20 changes: 15 additions & 5 deletions xla/python/ifrt_proxy/server/ifrt_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,14 @@ absl::StatusOr<BackendInterface::Response> IfrtBackend::HandleInit(
init_resp->set_runtime_type(AsProtoStringData(client_->runtime_type()));
init_resp->set_process_index(client_->process_index());

for (auto* device : client_->devices()) {
InitResponse::Device* d = init_resp->add_devices();
absl::Span<xla::ifrt::Device* const> all_devices;
if (version_.protocol_version() < 7) {
all_devices = client_->devices();
} else {
all_devices = client_->GetAllDevices();
}
for (auto* device : all_devices) {
InitResponse::Device* d = init_resp->add_all_devices();
d->set_id(device->Id().value());
d->set_device_kind(AsProtoStringData(device->Kind()));
if (auto default_memory = device->DefaultMemory(); default_memory.ok()) {
Expand All @@ -289,13 +295,17 @@ absl::StatusOr<BackendInterface::Response> IfrtBackend::HandleInit(
} else {
*d->mutable_attributes() = device->Attributes().ToProto();
}

if (device->IsAddressable()) {
init_resp->add_addressable_device_ids(device->Id().value());
}
}
for (auto* addressable_device : client_->addressable_devices()) {
init_resp->add_addressable_device_ids(addressable_device->Id().value());
for (auto* device : client_->devices()) {
init_resp->add_primary_device_ids(device->Id().value());
}

absl::flat_hash_map<int, xla::ifrt::Memory*> memories;
for (auto* device : client_->devices()) {
for (auto* device : all_devices) {
for (xla::ifrt::Memory* memory : device->Memories()) {
const auto [it, inserted] =
memories.insert({memory->Id().value(), memory});
Expand Down
Loading

0 comments on commit 5db9b92

Please sign in to comment.