Skip to content

Commit

Permalink
IFRT proxy: asynchronous and faster MakeArrayFromHostBuffer
Browse files Browse the repository at this point in the history
Note: I use the term 'control path' to refer to everything except `HostBufferStore.Store()` and `HostBufferStore.Lookup()` operations.

This CL improves performance with the following changes:
- The client manufactures array handles for `MakeArrayFromHostBufferRequest` (instead of the server generating them) and returns to the caller immediately after sending the request. Since ordering is maintained for control path requests across the proxy, future operations on the array do not require any special handling.
- The data-path `HostBufferStoreRequest` that corresponds to a `MakeArrayFromHostBufferRequest` is not ordered by the client (before or after) the `MakeArrayFromHostBufferRequest`. On the server-side, the loop that handles control path requests, when it sees a `MakeArrayFromHostBufferRequest`, blocks until the corresponding `HostBufferStoreRequest` is processed.
- The data-path (`HostBufferStore` implementation) and control path now use different gRPC channels.
- Resulting performance: BM_HostToDeviceAsync/1M/2k results in more than 3 GB/s, making it bottlenecked by gRPC `stream.Write()` latency. BM_HostToDeviceAsync/1K/98k (~4MB/s) was already bottlenecked by gRPC `stream.Write()` latency.

This CL also:
- Adds more XProf tracemes
- Introduces `global_flags.h` and `global_flags_google.cc` so we can conveniently use command-line flags in the proxy-client. This may not be ideal from a clean-code perspective, but makes it much easier to develop and debug the client.

PiperOrigin-RevId: 693790827
  • Loading branch information
Google-ML-Automation committed Nov 15, 2024
1 parent 4d5f691 commit ee0ba3b
Show file tree
Hide file tree
Showing 33 changed files with 764 additions and 252 deletions.
39 changes: 33 additions & 6 deletions xla/python/ifrt_proxy/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cc_library(
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:unbounded_work_queue",
"@tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -97,6 +98,7 @@ cc_library(
":host_buffer",
"//xla/python/ifrt",
"//xla/python/ifrt_proxy/common:ifrt_service_proto_cc",
"//xla/python/ifrt_proxy/common:prof_util",
"//xla/python/ifrt_proxy/common:test_utils",
"//xla/python/ifrt_proxy/common:types",
"//xla/tsl/profiler/utils:xplane_schema",
Expand Down Expand Up @@ -226,6 +228,7 @@ cc_library(
srcs = ["array.cc"],
hdrs = ["array.h"],
deps = [
":global_flags",
":rpc_helper",
"//xla:status_macros",
"//xla/python/ifrt",
Expand All @@ -237,6 +240,7 @@ cc_library(
"//xla/tsl/concurrency:ref_count",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down Expand Up @@ -286,7 +290,6 @@ cc_library(
"//xla/python/ifrt",
"//xla/python/ifrt_proxy/common:ifrt_service_proto_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

Expand Down Expand Up @@ -401,7 +404,6 @@ cc_library(
hdrs = ["host_buffer.h"],
deps = [
"//xla/python/ifrt",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
Expand All @@ -415,8 +417,6 @@ cc_library(
deps = [
":host_buffer",
"//xla/python/ifrt",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_googletest//:gtest",
Expand All @@ -433,14 +433,14 @@ cc_library(
"//xla/python/ifrt",
"//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto",
"//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc",
"//xla/python/ifrt_proxy/common:prof_util",
"//xla/tsl/protobuf:status_proto_cc",
"@com_github_grpc_grpc//:grpc++",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:unbounded_work_queue",
],
)
Expand All @@ -450,6 +450,7 @@ cc_library(
srcs = ["grpc_client.cc"],
deps = [
":client",
":global_flags",
":grpc_client_session",
":grpc_host_buffer",
":registry",
Expand All @@ -464,7 +465,6 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/log:log_entry",
"@com_google_absl//absl/log:log_sink",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
Expand All @@ -480,9 +480,11 @@ cc_library(
srcs = ["registry.cc"],
hdrs = ["registry.h"],
deps = [
":global_flags",
"//xla/python/ifrt",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -574,3 +576,28 @@ cc_library(
"@tsl//tsl/platform:statusor",
],
)

# Export headers referenced by the google-internal-version of global_flags.
exports_files(
["global_flags.h"],
visibility = if_google(
["//xla/python/ifrt_proxy/client/google:__pkg__"],
["//visibility:private"],
),
)

cc_library(
name = "global_flags_oss",
srcs = [
"global_flags.h",
"global_flags_oss.cc",
],
visibility = ["//visibility:private"],
deps = ["@com_google_absl//absl/base:no_destructor"],
)

cc_library(
name = "global_flags",
hdrs = ["global_flags.h"],
deps = [if_google("//xla/python/ifrt_proxy/client/google:global_flags_google", ":global_flags_oss")],
)
131 changes: 100 additions & 31 deletions xla/python/ifrt_proxy/client/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "xla/python/ifrt_proxy/client/array.h"

#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
Expand All @@ -23,6 +24,7 @@
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand All @@ -41,6 +43,7 @@
#include "xla/python/ifrt/remap_plan.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/ifrt_proxy/client/global_flags.h"
#include "xla/python/ifrt_proxy/client/rpc_helper.h"
#include "xla/python/ifrt_proxy/common/array_util.h"
#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h"
Expand All @@ -57,62 +60,129 @@ namespace proxy {

char Array::ID = 0;

using HostBufferSemantics = ::xla::ifrt::Client::HostBufferSemantics;

absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>
Array::MakeArrayFromHostBuffer(
xla::ifrt::Client* client, std::shared_ptr<RpcHelper> rpc_helper,
const void* data, DType dtype, Shape shape,
std::optional<absl::Span<const int64_t>> byte_strides,
std::shared_ptr<const Sharding> sharding,
xla::ifrt::Client::HostBufferSemantics semantics,
std::shared_ptr<const Sharding> sharding, HostBufferSemantics semantics,
std::function<void()> on_done_with_host_buffer) {
const uint64_t host_buffer_handle =
rpc_helper->host_buffer_store()->NextHandle();

if (dtype.kind() == DType::kString) {
absl::string_view mem_region;
if (dtype.kind() != DType::kString) {
TF_ASSIGN_OR_RETURN(
auto array_mem_region,
ArrayMemRegion::FromZerothElementPointer(
/*zeroth_element=*/data, dtype, shape, byte_strides));
mem_region = array_mem_region.mem_region();
} else {
// DType::kString
if (rpc_helper->version().protocol_version() < 9) {
return absl::UnimplementedError(
"String arrays are not supported in ifrt-proxy version < 9");
}
TF_ASSIGN_OR_RETURN(
const std::string serialized_string_buffer,
std::shared_ptr<std::string> owned_data,
SerializeStringHostBuffer(absl::MakeConstSpan(
static_cast<const absl::Cord*>(data), shape.num_elements())));
mem_region = *owned_data;
semantics = HostBufferSemantics::kImmutableUntilTransferCompletes;
std::function<void()> on_done(std::move(on_done_with_host_buffer));
on_done_with_host_buffer = [owned_data = std::move(owned_data),
on_done = std::move(on_done)]() {
if (on_done) {
std::move(on_done)();
}
};
}

const uint64_t host_buffer_handle = rpc_helper->NextHandle();

if (GetGlobalClientFlags()->synchronous_host_buffer_store ||
rpc_helper->version().protocol_version() < 10) {
// Synchronously send data and await.
TF_RETURN_IF_ERROR(rpc_helper->host_buffer_store()
->Store(host_buffer_handle,
absl::string_view(serialized_string_buffer))
->Store(host_buffer_handle, mem_region)
.Await());
if (on_done_with_host_buffer != nullptr) {
std::move(on_done_with_host_buffer)();
}
} else {
TF_ASSIGN_OR_RETURN(
const auto array_mem_region,
ArrayMemRegion::FromZerothElementPointer(
/*zeroth_element=*/data, dtype, shape, byte_strides));
TF_RETURN_IF_ERROR(
rpc_helper->host_buffer_store()
->Store(host_buffer_handle, array_mem_region.mem_region())
.Await());
// Asynchronously send data.

if (semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
auto alloc = std::make_shared<char[]>(mem_region.size());
memcpy(&alloc[0], mem_region.data(), mem_region.size());
mem_region = absl::string_view(&alloc[0], mem_region.size());
if (on_done_with_host_buffer != nullptr) {
std::move(on_done_with_host_buffer)();
}
on_done_with_host_buffer = [alloc = std::move(alloc)]() {};
}

// If the async-send results in an error, ignoring it may mean that the
// control-path hangs forever. Instead, we explicitly ensure the
// control-path gets disconnected (and so the entire session ends).
//
// While there are more fine-grained approaches to handle errors, we do not
// expect an error except for one that indicates being already disconnected
// from the server.
rpc_helper->host_buffer_store()
->Store(host_buffer_handle, mem_region)
.OnReady([on_done = std::move(on_done_with_host_buffer),
rpc_helper = std::weak_ptr<RpcHelper>(rpc_helper)](
absl::Status s) mutable {
if (!s.ok()) {
LOG(WARNING) << "Handling error in background data-transfer by "
<< "disconnecting from server (if not already "
<< "disconnected), error: " << s;
if (auto locked = rpc_helper.lock()) {
locked->Disconnect();
}
};
if (on_done != nullptr) {
std::move(on_done)();
}
});
}

auto req = std::make_unique<MakeArrayFromHostBufferRequest>();
req->set_host_buffer_handle(host_buffer_handle);
// Reuse the host_buffer_handle as also the client-manufactured
// array_handle.
req->set_array_handle(host_buffer_handle);
*req->mutable_dtype() = dtype.ToProto();
*req->mutable_shape() = shape.ToProto();
TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), sharding->ToProto());
if (byte_strides.has_value()) {
*req->mutable_byte_strides() = ToByteStridesProto(*byte_strides);
}

TF_ASSIGN_OR_RETURN(
auto response,
rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await());
const ArrayHandle handle{response->array_handle()};

if (on_done_with_host_buffer != nullptr) {
std::move(on_done_with_host_buffer)();
ArrayHandle arr_handle;
if (GetGlobalClientFlags()->synchronous_host_buffer_store ||
rpc_helper->version().protocol_version() < 10) {
TF_ASSIGN_OR_RETURN(
auto resp, rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await());
arr_handle.handle = resp->array_handle();
} else {
rpc_helper->MakeArrayFromHostBuffer(std::move(req))
.OnReady(
[host_buffer_handle](
absl::StatusOr<std::shared_ptr<MakeArrayFromHostBufferResponse>>
resp) {
if (resp.ok()) {
CHECK_EQ(resp.value()->array_handle(), host_buffer_handle);
} else {
LOG(ERROR) << "In background MakeArrayFromHostBuffer: "
<< resp.status();
}
});
arr_handle.handle = host_buffer_handle;
}

return tsl::RCReference<xla::ifrt::Array>(
tsl::MakeRef<Array>(client, std::move(rpc_helper), dtype,
std::move(shape), std::move(sharding), handle));
std::move(shape), std::move(sharding), arr_handle));
}

void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) {
Expand Down Expand Up @@ -355,11 +425,11 @@ Future<> Array::CopyToStringHostBuffer(
"Byte strides are not supported for string arrays."));
}

auto host_buffer_store = rpc_helper_->host_buffer_store();
const uint64_t host_buffer_handle = host_buffer_store->NextHandle();
const uint64_t host_buffer_handle = rpc_helper_->NextHandle();
req->set_host_buffer_handle(host_buffer_handle);
auto promise = Future<>::CreatePromise();
auto on_ready = [promise, host_buffer_store = std::move(host_buffer_store),
auto on_ready = [promise,
host_buffer_store = rpc_helper_->host_buffer_store(),
host_buffer_handle,
dst_buffer = static_cast<absl::Cord*>(data)](
absl::StatusOr<std::shared_ptr<CopyToHostBufferResponse>>
Expand Down Expand Up @@ -415,8 +485,7 @@ Future<> Array::CopyToHostBuffer(
if (byte_strides.has_value()) {
*req->mutable_byte_strides() = ToByteStridesProto(*byte_strides);
}
const uint64_t host_buffer_handle =
rpc_helper_->host_buffer_store()->NextHandle();
const uint64_t host_buffer_handle = rpc_helper_->NextHandle();
req->set_host_buffer_handle(host_buffer_handle);

auto promise = Future<>::CreatePromise();
Expand Down
13 changes: 6 additions & 7 deletions xla/python/ifrt_proxy/client/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ absl::StatusOr<absl::Cord> ExecuteLoadedHostCallback(
// Same as `ExecuteLoadedHostCallback`, except that it uses host buffer store to
// retrieve operands and store results.
absl::StatusOr<uint64_t> PrepareAndExecuteLoadedHostCallback(
ClientHostBufferStore* host_buffer_store,
xla::ifrt::LoadedHostCallback* loaded_host_callback,
RpcHelper* rpc_helper, xla::ifrt::LoadedHostCallback* loaded_host_callback,
uint64_t operand_handle) {
ClientHostBufferStore* host_buffer_store =
rpc_helper->host_buffer_store().get();
TF_ASSIGN_OR_RETURN(absl::Cord operands,
host_buffer_store->Lookup(operand_handle).Await());
absl::Cleanup cleanup = [&]() {
Expand All @@ -172,7 +173,7 @@ absl::StatusOr<uint64_t> PrepareAndExecuteLoadedHostCallback(
absl::Cord results,
ExecuteLoadedHostCallback(loaded_host_callback, std::move(operands)));

const uint64_t result_handle = host_buffer_store->NextHandle();
const uint64_t result_handle = rpc_helper->NextHandle();
TF_RETURN_IF_ERROR(host_buffer_store->Store(result_handle, results).Await());
return result_handle;
}
Expand Down Expand Up @@ -517,8 +518,7 @@ void LoadedExecutable::PollLoadedHostCallback(
auto f = [rpc_helper = rpc_helper_, handle,
loaded_host_callback = std::move(loaded_host_callback)]() {
while (true) {
const uint64_t operand_handle =
rpc_helper->host_buffer_store()->NextHandle();
const uint64_t operand_handle = rpc_helper->NextHandle();

auto poll_req = std::make_unique<LoadedHostCallbackPollRequest>();
poll_req->set_loaded_host_callback_handle(handle);
Expand All @@ -543,8 +543,7 @@ void LoadedExecutable::PollLoadedHostCallback(

absl::StatusOr<uint64_t> result_handle =
PrepareAndExecuteLoadedHostCallback(
rpc_helper->host_buffer_store().get(), loaded_host_callback.get(),
operand_handle);
rpc_helper.get(), loaded_host_callback.get(), operand_handle);
if (result_handle.ok()) {
ret_req->set_result_host_buffer_handle(*result_handle);
} else {
Expand Down
Loading

0 comments on commit ee0ba3b

Please sign in to comment.