Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IFRT proxy: asynchronous and faster MakeArrayFromHostBuffer #19407

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions xla/python/ifrt_proxy/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cc_library(
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:unbounded_work_queue",
"@tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -97,6 +98,7 @@ cc_library(
":host_buffer",
"//xla/python/ifrt",
"//xla/python/ifrt_proxy/common:ifrt_service_proto_cc",
"//xla/python/ifrt_proxy/common:prof_util",
"//xla/python/ifrt_proxy/common:test_utils",
"//xla/python/ifrt_proxy/common:types",
"//xla/tsl/profiler/utils:xplane_schema",
Expand Down Expand Up @@ -226,6 +228,7 @@ cc_library(
srcs = ["array.cc"],
hdrs = ["array.h"],
deps = [
":global_flags",
":rpc_helper",
"//xla:status_macros",
"//xla/python/ifrt",
Expand All @@ -237,6 +240,7 @@ cc_library(
"//xla/tsl/concurrency:ref_count",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down Expand Up @@ -286,7 +290,6 @@ cc_library(
"//xla/python/ifrt",
"//xla/python/ifrt_proxy/common:ifrt_service_proto_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

Expand Down Expand Up @@ -401,7 +404,6 @@ cc_library(
hdrs = ["host_buffer.h"],
deps = [
"//xla/python/ifrt",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
Expand All @@ -415,8 +417,6 @@ cc_library(
deps = [
":host_buffer",
"//xla/python/ifrt",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_googletest//:gtest",
Expand All @@ -433,14 +433,14 @@ cc_library(
"//xla/python/ifrt",
"//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto",
"//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc",
"//xla/python/ifrt_proxy/common:prof_util",
"//xla/tsl/protobuf:status_proto_cc",
"@com_github_grpc_grpc//:grpc++",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:unbounded_work_queue",
],
)
Expand All @@ -450,6 +450,7 @@ cc_library(
srcs = ["grpc_client.cc"],
deps = [
":client",
":global_flags",
":grpc_client_session",
":grpc_host_buffer",
":registry",
Expand All @@ -464,7 +465,6 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/log:log_entry",
"@com_google_absl//absl/log:log_sink",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
Expand All @@ -480,9 +480,11 @@ cc_library(
srcs = ["registry.cc"],
hdrs = ["registry.h"],
deps = [
":global_flags",
"//xla/python/ifrt",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -574,3 +576,28 @@ cc_library(
"@tsl//tsl/platform:statusor",
],
)

# Export global_flags.h so that the Google-internal implementation of
# global_flags (which lives in a different package) can reference the header.
# NOTE(review): if_google(a, b) presumably yields `a` for Google-internal
# builds and `b` for open-source builds, which would keep the header private
# in OSS -- confirm against the macro definition.
exports_files(
["global_flags.h"],
visibility = if_google(
["//xla/python/ifrt_proxy/client/google:__pkg__"],
["//visibility:private"],
),
)

# Implementation of the client global flags built from global_flags_oss.cc.
# The target is package-private; other targets are expected to depend on
# ":global_flags" below, which lists this target as its second if_google()
# argument (presumably the open-source choice, judging by the `_oss` suffix).
cc_library(
name = "global_flags_oss",
srcs = [
# The header is listed under srcs, not hdrs, so it is not exposed from here.
"global_flags.h",
"global_flags_oss.cc",
],
visibility = ["//visibility:private"],
deps = ["@com_google_absl//absl/debugging:leak_check"],
)

# Target that exposes global_flags.h to dependents (array, grpc_client and
# registry list ":global_flags" in their deps). The implementation it links is
# selected by if_google(): the Google-internal target or ":global_flags_oss".
cc_library(
name = "global_flags",
hdrs = ["global_flags.h"],
deps = [if_google("//xla/python/ifrt_proxy/client/google:global_flags_google", ":global_flags_oss")],
)
131 changes: 100 additions & 31 deletions xla/python/ifrt_proxy/client/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "xla/python/ifrt_proxy/client/array.h"

#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
Expand All @@ -23,6 +24,7 @@
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand All @@ -41,6 +43,7 @@
#include "xla/python/ifrt/remap_plan.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/ifrt_proxy/client/global_flags.h"
#include "xla/python/ifrt_proxy/client/rpc_helper.h"
#include "xla/python/ifrt_proxy/common/array_util.h"
#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h"
Expand All @@ -57,62 +60,129 @@ namespace proxy {

// Out-of-line definition of the static member Array::ID.
// NOTE(review): presumably the address of this char serves as the class's
// type identifier -- confirm against array.h.
char Array::ID = 0;

// Shorthand for the host-buffer semantics enum used in the signature and body
// of Array::MakeArrayFromHostBuffer.
using HostBufferSemantics = ::xla::ifrt::Client::HostBufferSemantics;

// Creates a proxy-client Array from a host buffer: the buffer's bytes are sent
// to the host buffer store under a freshly allocated handle, and a
// MakeArrayFromHostBufferRequest referring to that handle is issued.
//
// Depending on the global client flag `synchronous_host_buffer_store` and the
// negotiated protocol version, both steps are either awaited in this call or
// completed in the background via OnReady callbacks.
//
// `on_done_with_host_buffer`, if set, is invoked once this function no longer
// needs to read `data`.
//
// NOTE(review): this text is a rendered diff hunk. Removed and added lines are
// interleaved without +/- markers (for example the two `sharding` /
// `semantics` parameter lines below), so it is not compilable as shown.
absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>
Array::MakeArrayFromHostBuffer(
xla::ifrt::Client* client, std::shared_ptr<RpcHelper> rpc_helper,
const void* data, DType dtype, Shape shape,
std::optional<absl::Span<const int64_t>> byte_strides,
std::shared_ptr<const Sharding> sharding,
xla::ifrt::Client::HostBufferSemantics semantics,
std::shared_ptr<const Sharding> sharding, HostBufferSemantics semantics,
std::function<void()> on_done_with_host_buffer) {
// NOTE(review): this handle allocation and the `== DType::kString` test that
// follows look like the removed side of the diff; the handle is allocated
// again further down via rpc_helper->NextHandle().
const uint64_t host_buffer_handle =
rpc_helper->host_buffer_store()->NextHandle();

if (dtype.kind() == DType::kString) {
// View over the bytes that will be handed to the host buffer store.
absl::string_view mem_region;
if (dtype.kind() != DType::kString) {
// Non-string dtypes: derive the byte region from the zeroth-element pointer
// together with dtype, shape and the optional byte strides.
TF_ASSIGN_OR_RETURN(
auto array_mem_region,
ArrayMemRegion::FromZerothElementPointer(
/*zeroth_element=*/data, dtype, shape, byte_strides));
mem_region = array_mem_region.mem_region();
} else {
// DType::kString
// String arrays are rejected when the protocol version is below 9.
if (rpc_helper->version().protocol_version() < 9) {
return absl::UnimplementedError(
"String arrays are not supported in ifrt-proxy version < 9");
}
// `data` is treated as an array of shape.num_elements() absl::Cord values,
// which are serialized into a single buffer.
TF_ASSIGN_OR_RETURN(
const std::string serialized_string_buffer,
std::shared_ptr<std::string> owned_data,
SerializeStringHostBuffer(absl::MakeConstSpan(
static_cast<const absl::Cord*>(data), shape.num_elements())));
mem_region = *owned_data;
// NOTE(review): the semantics are overridden here, presumably so that the
// asynchronous path below skips its extra copy (the serialized buffer is
// already owned by this function) -- confirm intent.
semantics = HostBufferSemantics::kImmutableUntilTransferCompletes;
// Wrap the caller's callback in a closure that also holds `owned_data`, so
// the serialized bytes stay referenced for as long as the closure exists.
std::function<void()> on_done(std::move(on_done_with_host_buffer));
on_done_with_host_buffer = [owned_data = std::move(owned_data),
on_done = std::move(on_done)]() {
if (on_done) {
std::move(on_done)();
}
};
}

const uint64_t host_buffer_handle = rpc_helper->NextHandle();

// Take the blocking path when the global flag requests it or when the
// protocol version is below 10.
if (GetGlobalClientFlags()->synchronous_host_buffer_store ||
rpc_helper->version().protocol_version() < 10) {
// Synchronously send data and await.
TF_RETURN_IF_ERROR(rpc_helper->host_buffer_store()
->Store(host_buffer_handle,
absl::string_view(serialized_string_buffer))
->Store(host_buffer_handle, mem_region)
.Await());
// The store has been awaited, so the host buffer can be released now.
if (on_done_with_host_buffer != nullptr) {
std::move(on_done_with_host_buffer)();
}
} else {
// NOTE(review): the next two statements (FromZerothElementPointer followed
// by an awaited Store) look like the removed, pre-change non-string path
// rather than part of the asynchronous branch -- confirm against the file.
TF_ASSIGN_OR_RETURN(
const auto array_mem_region,
ArrayMemRegion::FromZerothElementPointer(
/*zeroth_element=*/data, dtype, shape, byte_strides));
TF_RETURN_IF_ERROR(
rpc_helper->host_buffer_store()
->Store(host_buffer_handle, array_mem_region.mem_region())
.Await());
// Asynchronously send data.

// For kImmutableOnlyDuringCall, copy the bytes into an owned allocation so
// that the background send reads the copy instead of the caller's buffer.
// The caller's callback is invoked immediately after the copy, and a new
// callback that only holds the allocation takes its place.
// NOTE(review): std::make_shared<char[]>(n) is the array form added in
// C++20 -- confirm the toolchain's language level.
if (semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
auto alloc = std::make_shared<char[]>(mem_region.size());
memcpy(&alloc[0], mem_region.data(), mem_region.size());
mem_region = absl::string_view(&alloc[0], mem_region.size());
if (on_done_with_host_buffer != nullptr) {
std::move(on_done_with_host_buffer)();
}
on_done_with_host_buffer = [alloc = std::move(alloc)]() {};
}

// If the async-send results in an error, ignoring it may mean that the
// control-path hangs forever. Instead, we explicitly ensure the
// control-path gets disconnected (and so the entire session ends).
//
// While there are more fine-grained approaches to handle errors, we do not
// expect an error except for one that indicates being already disconnected
// from the server.
//
// The callback captures RpcHelper through a weak_ptr and only calls
// Disconnect() if the helper can still be locked at that point.
rpc_helper->host_buffer_store()
->Store(host_buffer_handle, mem_region)
.OnReady([on_done = std::move(on_done_with_host_buffer),
rpc_helper = std::weak_ptr<RpcHelper>(rpc_helper)](
absl::Status s) mutable {
if (!s.ok()) {
LOG(WARNING) << "Handling error in background data-transfer by "
<< "disconnecting from server (if not already "
<< "disconnected), error: " << s;
if (auto locked = rpc_helper.lock()) {
locked->Disconnect();
}
};
// Runs whether or not the store succeeded.
if (on_done != nullptr) {
std::move(on_done)();
}
});
}

// Build the request describing the array to create from the stored buffer.
auto req = std::make_unique<MakeArrayFromHostBufferRequest>();
req->set_host_buffer_handle(host_buffer_handle);
// Reuse host_buffer_handle as the client-manufactured array_handle as well.
req->set_array_handle(host_buffer_handle);
*req->mutable_dtype() = dtype.ToProto();
*req->mutable_shape() = shape.ToProto();
TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), sharding->ToProto());
if (byte_strides.has_value()) {
*req->mutable_byte_strides() = ToByteStridesProto(*byte_strides);
}

// NOTE(review): the awaited request, the `handle` local and the unclosed
// `if (on_done_with_host_buffer != nullptr)` below look like removed lines of
// the diff; the statements that follow them use `arr_handle` instead.
TF_ASSIGN_OR_RETURN(
auto response,
rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await());
const ArrayHandle handle{response->array_handle()};

if (on_done_with_host_buffer != nullptr) {
std::move(on_done_with_host_buffer)();
ArrayHandle arr_handle;
if (GetGlobalClientFlags()->synchronous_host_buffer_store ||
rpc_helper->version().protocol_version() < 10) {
// Blocking path: wait for the response and adopt the array handle it
// carries.
TF_ASSIGN_OR_RETURN(
auto resp, rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await());
arr_handle.handle = resp->array_handle();
} else {
// Background path: do not wait. The array handle is the client-chosen
// host_buffer_handle; the callback CHECKs that a successful response
// carries the same value and logs any error status.
rpc_helper->MakeArrayFromHostBuffer(std::move(req))
.OnReady(
[host_buffer_handle](
absl::StatusOr<std::shared_ptr<MakeArrayFromHostBufferResponse>>
resp) {
if (resp.ok()) {
CHECK_EQ(resp.value()->array_handle(), host_buffer_handle);
} else {
LOG(ERROR) << "In background MakeArrayFromHostBuffer: "
<< resp.status();
}
});
arr_handle.handle = host_buffer_handle;
}

// Wrap the resulting handle in a proxy Array object.
// NOTE(review): the two argument lines ending in `handle));` and
// `arr_handle));` are the old and new variants of the same line.
return tsl::RCReference<xla::ifrt::Array>(
tsl::MakeRef<Array>(client, std::move(rpc_helper), dtype,
std::move(shape), std::move(sharding), handle));
std::move(shape), std::move(sharding), arr_handle));
}

void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) {
Expand Down Expand Up @@ -355,11 +425,11 @@ Future<> Array::CopyToStringHostBuffer(
"Byte strides are not supported for string arrays."));
}

auto host_buffer_store = rpc_helper_->host_buffer_store();
const uint64_t host_buffer_handle = host_buffer_store->NextHandle();
const uint64_t host_buffer_handle = rpc_helper_->NextHandle();
req->set_host_buffer_handle(host_buffer_handle);
auto promise = Future<>::CreatePromise();
auto on_ready = [promise, host_buffer_store = std::move(host_buffer_store),
auto on_ready = [promise,
host_buffer_store = rpc_helper_->host_buffer_store(),
host_buffer_handle,
dst_buffer = static_cast<absl::Cord*>(data)](
absl::StatusOr<std::shared_ptr<CopyToHostBufferResponse>>
Expand Down Expand Up @@ -415,8 +485,7 @@ Future<> Array::CopyToHostBuffer(
if (byte_strides.has_value()) {
*req->mutable_byte_strides() = ToByteStridesProto(*byte_strides);
}
const uint64_t host_buffer_handle =
rpc_helper_->host_buffer_store()->NextHandle();
const uint64_t host_buffer_handle = rpc_helper_->NextHandle();
req->set_host_buffer_handle(host_buffer_handle);

auto promise = Future<>::CreatePromise();
Expand Down
13 changes: 6 additions & 7 deletions xla/python/ifrt_proxy/client/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ absl::StatusOr<absl::Cord> ExecuteLoadedHostCallback(
// Same as `ExecuteLoadedHostCallback`, except that it uses host buffer store to
// retrieve operands and store results.
absl::StatusOr<uint64_t> PrepareAndExecuteLoadedHostCallback(
ClientHostBufferStore* host_buffer_store,
xla::ifrt::LoadedHostCallback* loaded_host_callback,
RpcHelper* rpc_helper, xla::ifrt::LoadedHostCallback* loaded_host_callback,
uint64_t operand_handle) {
ClientHostBufferStore* host_buffer_store =
rpc_helper->host_buffer_store().get();
TF_ASSIGN_OR_RETURN(absl::Cord operands,
host_buffer_store->Lookup(operand_handle).Await());
absl::Cleanup cleanup = [&]() {
Expand All @@ -172,7 +173,7 @@ absl::StatusOr<uint64_t> PrepareAndExecuteLoadedHostCallback(
absl::Cord results,
ExecuteLoadedHostCallback(loaded_host_callback, std::move(operands)));

const uint64_t result_handle = host_buffer_store->NextHandle();
const uint64_t result_handle = rpc_helper->NextHandle();
TF_RETURN_IF_ERROR(host_buffer_store->Store(result_handle, results).Await());
return result_handle;
}
Expand Down Expand Up @@ -517,8 +518,7 @@ void LoadedExecutable::PollLoadedHostCallback(
auto f = [rpc_helper = rpc_helper_, handle,
loaded_host_callback = std::move(loaded_host_callback)]() {
while (true) {
const uint64_t operand_handle =
rpc_helper->host_buffer_store()->NextHandle();
const uint64_t operand_handle = rpc_helper->NextHandle();

auto poll_req = std::make_unique<LoadedHostCallbackPollRequest>();
poll_req->set_loaded_host_callback_handle(handle);
Expand All @@ -543,8 +543,7 @@ void LoadedExecutable::PollLoadedHostCallback(

absl::StatusOr<uint64_t> result_handle =
PrepareAndExecuteLoadedHostCallback(
rpc_helper->host_buffer_store().get(), loaded_host_callback.get(),
operand_handle);
rpc_helper.get(), loaded_host_callback.get(), operand_handle);
if (result_handle.ok()) {
ret_req->set_result_host_buffer_handle(*result_handle);
} else {
Expand Down
Loading
Loading