diff --git a/xla/python/ifrt_proxy/client/BUILD b/xla/python/ifrt_proxy/client/BUILD index bfbf24d0a8d9d..6dee7ac9a5087 100644 --- a/xla/python/ifrt_proxy/client/BUILD +++ b/xla/python/ifrt_proxy/client/BUILD @@ -49,6 +49,7 @@ cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:unbounded_work_queue", + "@tsl//tsl/profiler/lib:traceme", ], ) @@ -97,6 +98,7 @@ cc_library( ":host_buffer", "//xla/python/ifrt", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:prof_util", "//xla/python/ifrt_proxy/common:test_utils", "//xla/python/ifrt_proxy/common:types", "//xla/tsl/profiler/utils:xplane_schema", @@ -226,6 +228,7 @@ cc_library( srcs = ["array.cc"], hdrs = ["array.h"], deps = [ + ":global_flags", ":rpc_helper", "//xla:status_macros", "//xla/python/ifrt", @@ -237,6 +240,7 @@ cc_library( "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -286,7 +290,6 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", ], ) @@ -401,7 +404,6 @@ cc_library( hdrs = ["host_buffer.h"], deps = [ "//xla/python/ifrt", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", @@ -415,8 +417,6 @@ cc_library( deps = [ ":host_buffer", "//xla/python/ifrt", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_googletest//:gtest", @@ -433,6 +433,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:prof_util", "//xla/tsl/protobuf:status_proto_cc", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/log", @@ -440,7 +441,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/platform:env", "@tsl//tsl/platform:unbounded_work_queue", ], ) @@ -450,6 +450,7 @@ cc_library( srcs = ["grpc_client.cc"], deps = [ ":client", + ":global_flags", ":grpc_client_session", ":grpc_host_buffer", ":registry", @@ -464,7 +465,6 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/log:log_entry", - "@com_google_absl//absl/log:log_sink", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", @@ -480,9 +480,11 @@ cc_library( srcs = ["registry.cc"], hdrs = ["registry.h"], deps = [ + ":global_flags", "//xla/python/ifrt", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -574,3 +576,28 @@ cc_library( "@tsl//tsl/platform:statusor", ], ) + +# Export headers referenced by the google-internal-version of global_flags. +exports_files( + ["global_flags.h"], + visibility = if_google( + ["//xla/python/ifrt_proxy/client/google:__pkg__"], + ["//visibility:private"], + ), +) + +cc_library( + name = "global_flags_oss", + srcs = [ + "global_flags.h", + "global_flags_oss.cc", + ], + visibility = ["//visibility:private"], + deps = ["@com_google_absl//absl/base:no_destructor"], +) + +cc_library( + name = "global_flags", + hdrs = ["global_flags.h"], + deps = [if_google("//xla/python/ifrt_proxy/client/google:global_flags_google", ":global_flags_oss")], +) diff --git a/xla/python/ifrt_proxy/client/array.cc b/xla/python/ifrt_proxy/client/array.cc index 3195793dca80c..d37ec5ab762fd 100644 --- a/xla/python/ifrt_proxy/client/array.cc +++ b/xla/python/ifrt_proxy/client/array.cc @@ -15,6 +15,7 @@ #include "xla/python/ifrt_proxy/client/array.h" #include +#include #include #include #include @@ -23,6 +24,7 @@ #include #include "absl/cleanup/cleanup.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -41,6 +43,7 @@ #include "xla/python/ifrt/remap_plan.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/global_flags.h" #include "xla/python/ifrt_proxy/client/rpc_helper.h" #include "xla/python/ifrt_proxy/common/array_util.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" @@ -57,43 +60,98 @@ namespace proxy { char Array::ID = 0; +using HostBufferSemantics = ::xla::ifrt::Client::HostBufferSemantics; + absl::StatusOr> Array::MakeArrayFromHostBuffer( xla::ifrt::Client* client, std::shared_ptr rpc_helper, const void* data, DType dtype, Shape shape, std::optional> byte_strides, - std::shared_ptr sharding, - xla::ifrt::Client::HostBufferSemantics semantics, + std::shared_ptr sharding, HostBufferSemantics semantics, std::function on_done_with_host_buffer) { - const uint64_t host_buffer_handle = - rpc_helper->host_buffer_store()->NextHandle(); - - if (dtype.kind() == DType::kString) { + absl::string_view mem_region; + if (dtype.kind() != DType::kString) { + TF_ASSIGN_OR_RETURN( + auto array_mem_region, + ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/data, dtype, shape, byte_strides)); + mem_region = array_mem_region.mem_region(); + } else { + // DType::kString if (rpc_helper->version().protocol_version() < 9) { return absl::UnimplementedError( "String arrays are not supported in ifrt-proxy version < 9"); } TF_ASSIGN_OR_RETURN( - const std::string serialized_string_buffer, + std::shared_ptr owned_data, SerializeStringHostBuffer(absl::MakeConstSpan( static_cast(data), shape.num_elements()))); + mem_region = *owned_data; + semantics = HostBufferSemantics::kImmutableUntilTransferCompletes; + std::function on_done(std::move(on_done_with_host_buffer)); + on_done_with_host_buffer = [owned_data = std::move(owned_data), + on_done = std::move(on_done)]() { + if (on_done) { + std::move(on_done)(); + } + }; + } + + const uint64_t host_buffer_handle = rpc_helper->NextHandle(); + + if (GetGlobalClientFlags()->synchronous_host_buffer_store || + rpc_helper->version().protocol_version() < 10) { + // Synchronously send data and await. TF_RETURN_IF_ERROR(rpc_helper->host_buffer_store() - ->Store(host_buffer_handle, - absl::string_view(serialized_string_buffer)) + ->Store(host_buffer_handle, mem_region) .Await()); + if (on_done_with_host_buffer != nullptr) { + std::move(on_done_with_host_buffer)(); + } } else { - TF_ASSIGN_OR_RETURN( - const auto array_mem_region, - ArrayMemRegion::FromZerothElementPointer( - /*zeroth_element=*/data, dtype, shape, byte_strides)); - TF_RETURN_IF_ERROR( - rpc_helper->host_buffer_store() - ->Store(host_buffer_handle, array_mem_region.mem_region()) - .Await()); + // Asynchronously send data. + + if (semantics == HostBufferSemantics::kImmutableOnlyDuringCall) { + auto alloc = std::make_shared(mem_region.size()); + memcpy(&alloc[0], mem_region.data(), mem_region.size()); + mem_region = absl::string_view(&alloc[0], mem_region.size()); + if (on_done_with_host_buffer != nullptr) { + std::move(on_done_with_host_buffer)(); + } + on_done_with_host_buffer = [alloc = std::move(alloc)]() {}; + } + + // If the async-send results in an error, ignoring it may mean that the + // control-path hangs forever. Instead, we explicitly ensure the + // control-path gets disconnected (and so the entire session ends). + // + // While there are more fine-grained approaches to handle errors, we do not + // expect an error except for one that indicates being already disconnected + // from the server. + rpc_helper->host_buffer_store() + ->Store(host_buffer_handle, mem_region) + .OnReady([on_done = std::move(on_done_with_host_buffer), + rpc_helper = std::weak_ptr(rpc_helper)]( + absl::Status s) mutable { + if (!s.ok()) { + LOG(WARNING) << "Handling error in background data-transfer by " + << "disconnecting from server (if not already " + << "disconnected), error: " << s; + if (auto locked = rpc_helper.lock()) { + locked->Disconnect(); + } + }; + if (on_done != nullptr) { + std::move(on_done)(); + } + }); } auto req = std::make_unique(); req->set_host_buffer_handle(host_buffer_handle); + // Reuse the host_buffer_handle as also the client-manufactured + // array_handle. + req->set_array_handle(host_buffer_handle); *req->mutable_dtype() = dtype.ToProto(); *req->mutable_shape() = shape.ToProto(); TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), sharding->ToProto()); @@ -101,18 +159,30 @@ Array::MakeArrayFromHostBuffer( *req->mutable_byte_strides() = ToByteStridesProto(*byte_strides); } - TF_ASSIGN_OR_RETURN( - auto response, - rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await()); - const ArrayHandle handle{response->array_handle()}; - - if (on_done_with_host_buffer != nullptr) { - std::move(on_done_with_host_buffer)(); + ArrayHandle arr_handle; + if (GetGlobalClientFlags()->synchronous_host_buffer_store || + rpc_helper->version().protocol_version() < 10) { + TF_ASSIGN_OR_RETURN( + auto resp, rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await()); + arr_handle.handle = resp->array_handle(); + } else { + rpc_helper->MakeArrayFromHostBuffer(std::move(req)) + .OnReady( + [host_buffer_handle]( + absl::StatusOr> + resp) { + if (resp.ok()) { + CHECK_EQ(resp.value()->array_handle(), host_buffer_handle); + } else { + LOG(ERROR) << "In background MakeArrayFromHostBuffer: " + << resp.status(); + } + }); + arr_handle.handle = host_buffer_handle; } - return tsl::RCReference( tsl::MakeRef(client, std::move(rpc_helper), dtype, - std::move(shape), std::move(sharding), handle)); + std::move(shape), std::move(sharding), arr_handle)); } void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { @@ -355,11 +425,11 @@ Future<> Array::CopyToStringHostBuffer( "Byte strides are not supported for string arrays.")); } - auto host_buffer_store = rpc_helper_->host_buffer_store(); - const uint64_t host_buffer_handle = host_buffer_store->NextHandle(); + const uint64_t host_buffer_handle = rpc_helper_->NextHandle(); req->set_host_buffer_handle(host_buffer_handle); auto promise = Future<>::CreatePromise(); - auto on_ready = [promise, host_buffer_store = std::move(host_buffer_store), + auto on_ready = [promise, + host_buffer_store = rpc_helper_->host_buffer_store(), host_buffer_handle, dst_buffer = static_cast(data)]( absl::StatusOr> @@ -415,8 +485,7 @@ Future<> Array::CopyToHostBuffer( if (byte_strides.has_value()) { *req->mutable_byte_strides() = ToByteStridesProto(*byte_strides); } - const uint64_t host_buffer_handle = - rpc_helper_->host_buffer_store()->NextHandle(); + const uint64_t host_buffer_handle = rpc_helper_->NextHandle(); req->set_host_buffer_handle(host_buffer_handle); auto promise = Future<>::CreatePromise(); diff --git a/xla/python/ifrt_proxy/client/executable.cc b/xla/python/ifrt_proxy/client/executable.cc index f284a5bbd75af..349f1ef667ef9 100644 --- a/xla/python/ifrt_proxy/client/executable.cc +++ b/xla/python/ifrt_proxy/client/executable.cc @@ -155,9 +155,10 @@ absl::StatusOr ExecuteLoadedHostCallback( // Same as `ExecuteLoadedHostCallback`, except that it uses host buffer store to // retrieve operands and store results. absl::StatusOr PrepareAndExecuteLoadedHostCallback( - ClientHostBufferStore* host_buffer_store, - xla::ifrt::LoadedHostCallback* loaded_host_callback, + RpcHelper* rpc_helper, xla::ifrt::LoadedHostCallback* loaded_host_callback, uint64_t operand_handle) { + ClientHostBufferStore* host_buffer_store = + rpc_helper->host_buffer_store().get(); TF_ASSIGN_OR_RETURN(absl::Cord operands, host_buffer_store->Lookup(operand_handle).Await()); absl::Cleanup cleanup = [&]() { @@ -172,7 +173,7 @@ absl::StatusOr PrepareAndExecuteLoadedHostCallback( absl::Cord results, ExecuteLoadedHostCallback(loaded_host_callback, std::move(operands))); - const uint64_t result_handle = host_buffer_store->NextHandle(); + const uint64_t result_handle = rpc_helper->NextHandle(); TF_RETURN_IF_ERROR(host_buffer_store->Store(result_handle, results).Await()); return result_handle; } @@ -517,8 +518,7 @@ void LoadedExecutable::PollLoadedHostCallback( auto f = [rpc_helper = rpc_helper_, handle, loaded_host_callback = std::move(loaded_host_callback)]() { while (true) { - const uint64_t operand_handle = - rpc_helper->host_buffer_store()->NextHandle(); + const uint64_t operand_handle = rpc_helper->NextHandle(); auto poll_req = std::make_unique(); poll_req->set_loaded_host_callback_handle(handle); @@ -543,8 +543,7 @@ void LoadedExecutable::PollLoadedHostCallback( absl::StatusOr result_handle = PrepareAndExecuteLoadedHostCallback( - rpc_helper->host_buffer_store().get(), loaded_host_callback.get(), - operand_handle); + rpc_helper.get(), loaded_host_callback.get(), operand_handle); if (result_handle.ok()) { ret_req->set_result_host_buffer_handle(*result_handle); } else { diff --git a/xla/python/ifrt_proxy/client/global_flags.h b/xla/python/ifrt_proxy/client/global_flags.h new file mode 100644 index 0000000000000..5dae020b32f4d --- /dev/null +++ b/xla/python/ifrt_proxy/client/global_flags.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GLOBAL_FLAGS_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_GLOBAL_FLAGS_H_ + +#include + +namespace xla { +namespace ifrt { +namespace proxy { + +// Flags that are set based on command-line options or environment variables. +// As of November 2024, the OSSed code does not actually have any mechanism +// to configure these flags (global_flags_oss.cc has default values that are +// compile-time constants); Google-internal code allows it to be configured from +// command-line options. +struct GlobalClientFlags { + // Setting to true reverts to implementation from before Nov 2024, where + // host buffer stores were issued synchronously and waited upon. + // TODO(madthanu): Remove flag once there is confidence that the asynchronous + // codepath works well. + bool synchronous_host_buffer_store; +}; + +GlobalClientFlags* GetGlobalClientFlags(); + +inline std::ostream& operator<<(std::ostream& os, GlobalClientFlags flags) { + return os << "xla::ifrt::proxy::GlobalClientFlags{" + << "synchronous_host_buffer_store=" + << flags.synchronous_host_buffer_store << "}"; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_GLOBAL_FLAGS_H_ diff --git a/xla/python/ifrt_proxy/client/global_flags_oss.cc b/xla/python/ifrt_proxy/client/global_flags_oss.cc new file mode 100644 index 0000000000000..cfe35024a2176 --- /dev/null +++ b/xla/python/ifrt_proxy/client/global_flags_oss.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/base/no_destructor.h" +#include "xla/python/ifrt_proxy/client/global_flags.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +static GlobalClientFlags DefaultGlobalClientFlags() { + GlobalClientFlags result; + result.synchronous_host_buffer_store = false; + return result; +}; + +GlobalClientFlags* GetGlobalClientFlags() { + static absl::NoDestructor result( + DefaultGlobalClientFlags()); + return result.get(); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/grpc_client.cc b/xla/python/ifrt_proxy/client/grpc_client.cc index 590fff2a5c3e1..bb1edc754438b 100644 --- a/xla/python/ifrt_proxy/client/grpc_client.cc +++ b/xla/python/ifrt_proxy/client/grpc_client.cc @@ -20,7 +20,6 @@ #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/log/log_entry.h" -#include "absl/log/log_sink.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -30,6 +29,7 @@ #include "xla/pjrt/distributed/util.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client.h" +#include "xla/python/ifrt_proxy/client/global_flags.h" #include "xla/python/ifrt_proxy/client/grpc_client_session.h" #include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" #include "xla/python/ifrt_proxy/client/registry.h" @@ -66,7 +66,7 @@ absl::StatusOr> AttemptConnection( // TODO(b/266635130): Move gRPC stub creation to be outside of `Client` so // that we can pass mock `ClientSession` to the client. - auto stub = CreateGrpcStub(server_address); + auto control_path_stub = CreateGrpcStub(server_address); auto session_disconnect_cb = [init_response = Future>( @@ -102,16 +102,16 @@ absl::StatusOr> AttemptConnection( ::grpc::ClientContext context; GrpcGetVersionResponse response; - TF_RETURN_IF_ERROR( - xla::FromGrpcStatus(stub->GetVersion(&context, request, &response))); + TF_RETURN_IF_ERROR(xla::FromGrpcStatus( + control_path_stub->GetVersion(&context, request, &response))); CHECK_GE(response.version().protocol_version(), kClientMinVersion); CHECK_LE(response.version().protocol_version(), kClientMaxVersion); *metadata.mutable_version() = response.version(); } - auto session = - GrpcClientSession::Create(stub, metadata, session_disconnect_cb); + auto session = GrpcClientSession::Create(control_path_stub, metadata, + session_disconnect_cb); rpc_helper = std::make_unique(metadata.version(), std::move(session)); @@ -129,8 +129,15 @@ absl::StatusOr> AttemptConnection( auto init_response, Future>(init_response_promise).Await()); + bool reuse_control_path_stub_for_data_path = + GetGlobalClientFlags()->synchronous_host_buffer_store || + (metadata.version().protocol_version() < 10); + auto data_path_stub = reuse_control_path_stub_for_data_path + ? control_path_stub + : CreateGrpcStub(server_address); + auto host_buffer_store = std::make_unique( - stub, metadata.version(), init_response->session_id()); + data_path_stub, metadata.version(), init_response->session_id()); rpc_helper->set_host_buffer_store(std::move(host_buffer_store)); return Client::Create(std::move(rpc_helper), std::move(*init_response)); diff --git a/xla/python/ifrt_proxy/client/grpc_client_session.cc b/xla/python/ifrt_proxy/client/grpc_client_session.cc index b555bc62d2939..456264c66f1f2 100644 --- a/xla/python/ifrt_proxy/client/grpc_client_session.cc +++ b/xla/python/ifrt_proxy/client/grpc_client_session.cc @@ -48,6 +48,7 @@ #include "tsl/platform/logging.h" #include "tsl/platform/threadpool.h" #include "tsl/platform/unbounded_work_queue.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace ifrt { @@ -156,6 +157,7 @@ absl::Status GrpcClientSession::Enqueue(std::unique_ptr req, CHECK_EQ(req->mutable_request_metadata()->op_id(), 0); req->mutable_request_metadata()->set_op_id(op_id); + tsl::profiler::TraceMe t("grpc_stream_write"); if (!stream_->Write(*req)) { CHECK(response_callbacks_->Pop(op_id).has_value()); return absl::UnknownError("GrpcClientSession: writing to stream failed."); @@ -249,6 +251,7 @@ std::shared_ptr CreateGrpcStub( // model compilation. args.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1); args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1); + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true); std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateCustomChannel( std::string(server_address), GetClientCredentials(), args); VLOG(0) << " Established channel."; diff --git a/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc index b80f84b059377..40fa159847c0d 100644 --- a/xla/python/ifrt_proxy/client/grpc_host_buffer.cc +++ b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc @@ -14,7 +14,6 @@ #include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" -#include #include #include #include @@ -33,8 +32,8 @@ #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/prof_util.h" #include "xla/tsl/protobuf/status.pb.h" -#include "tsl/platform/env.h" #include "tsl/platform/unbounded_work_queue.h" namespace xla { @@ -43,59 +42,68 @@ namespace proxy { static constexpr int64_t kChunkSize = 1024 * 1024; +static void SetDataFromStringView(GrpcHostBufferStoreRequest& req, + absl::string_view data) { +#if defined(PLATFORM_GOOGLE) + req.set_alias_data(data); +#else + // TODO(b/325306748): Find a way to not do a memory-copy. + req.set_data(std::string(data)); +#endif +} + GrpcClientHostBufferStore::GrpcClientHostBufferStore( std::shared_ptr stub, IfrtProxyVersion version, uint64_t session_id) : stub_(std::move(stub)), version_(std::move(version)), session_id_(session_id), - lookup_work_queue_(std::make_unique( + work_queue_(std::make_unique( tsl::Env::Default(), "HostBufferStoreLookupsWorkQueue")) {} GrpcClientHostBufferStore::~GrpcClientHostBufferStore() { LOG(INFO) << "Waiting for destruction of HostBufferStoreLookupsWorkQueue..."; - lookup_work_queue_.reset(); + work_queue_.reset(); LOG(INFO) << "Destructed HostBufferStoreLookupsWorkQueue."; } -uint64_t GrpcClientHostBufferStore::NextHandle() { - return next_handle_.fetch_add(1, std::memory_order_relaxed); -} - Future<> GrpcClientHostBufferStore::Store(uint64_t handle, absl::string_view data) { - // The current implementation synchronously sends host buffer chunks. We may - // consider making it asynchronous if the caller can leverage such asynchrony. + auto promise = Future<>::CreatePromise(); - GrpcHostBufferStoreMetadata metadata; - metadata.set_session_id(session_id_); - metadata.set_handle(handle); - metadata.set_buffer_size(data.size()); + XFlowHelper flow("GrpcClientHostBufferStore::Store"); + flow.InstantActivity(); - ::grpc::ClientContext context; - context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin", - metadata.SerializeAsString()); + std::unique_ptr buffered_data; - GrpcHostBufferStoreResponse response; - auto writer = stub_->HostBufferStore(&context, &response); + work_queue_->Schedule([this, handle, promise, data, flow]() mutable -> void { + auto span = flow.Span(); + GrpcHostBufferStoreMetadata metadata; + metadata.set_session_id(session_id_); + metadata.set_handle(handle); + metadata.set_buffer_size(data.size()); - for (int64_t offset = 0; offset < data.size(); offset += kChunkSize) { - GrpcHostBufferStoreRequest request; -#if defined(PLATFORM_GOOGLE) - request.set_alias_data(data.substr(offset, kChunkSize)); -#else - // TODO(b/325306748): Find a way to not do a memory-copy. - request.set_data(std::string(data.substr(offset, kChunkSize))); -#endif - writer->Write(request); - } + ::grpc::ClientContext context; + context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin", + metadata.SerializeAsString()); - if (!writer->WritesDone()) { - return Future<>( - absl::InternalError("Failed to write all host buffer chunks")); - } + GrpcHostBufferStoreResponse response; + auto writer = stub_->HostBufferStore(&context, &response); - return Future<>(xla::FromGrpcStatus(writer->Finish())); + for (int64_t offset = 0; offset < data.size(); offset += kChunkSize) { + GrpcHostBufferStoreRequest request; + SetDataFromStringView(request, data.substr(offset, kChunkSize)); + writer->Write(request); + } + + if (!writer->WritesDone()) { + promise.Set( + absl::InternalError("Failed to write all host buffer chunks")); + } + + promise.Set(xla::FromGrpcStatus(writer->Finish())); + }); + return Future<>(promise); } Future<> GrpcClientHostBufferStore::Store(uint64_t handle, @@ -118,12 +126,7 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle, for (absl::string_view chunk : data.Chunks()) { for (int64_t offset = 0; offset < chunk.size(); offset += kChunkSize) { GrpcHostBufferStoreRequest request; -#if defined(PLATFORM_GOOGLE) - request.set_alias_data(chunk.substr(offset, kChunkSize)); -#else - // TODO(b/325306748): Find a way to not do a memory-copy. - request.set_data(std::string(chunk.substr(offset, kChunkSize))); -#endif + SetDataFromStringView(request, chunk.substr(offset, kChunkSize)); writer->Write(request); } } @@ -138,7 +141,7 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle, Future GrpcClientHostBufferStore::Lookup(uint64_t handle) { auto promise = Future::CreatePromise(); - lookup_work_queue_->Schedule([this, handle, promise]() mutable -> void { + work_queue_->Schedule([this, handle, promise]() mutable -> void { GrpcHostBufferLookupRequest request; request.set_handle(handle); request.set_session_id(session_id_); diff --git a/xla/python/ifrt_proxy/client/grpc_host_buffer.h b/xla/python/ifrt_proxy/client/grpc_host_buffer.h index 50ab06a7fd894..6a0c433f2e978 100644 --- a/xla/python/ifrt_proxy/client/grpc_host_buffer.h +++ b/xla/python/ifrt_proxy/client/grpc_host_buffer.h @@ -17,11 +17,9 @@ #ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ #define XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ -#include #include #include -#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "xla/python/ifrt/future.h" @@ -43,7 +41,6 @@ class GrpcClientHostBufferStore : public ClientHostBufferStore { // Implements ClientHostBufferStore. - uint64_t NextHandle() override; Future<> Store(uint64_t handle, absl::string_view data) override; Future<> Store(uint64_t handle, const absl::Cord& data) override; Future Lookup(uint64_t handle) override; @@ -53,14 +50,13 @@ class GrpcClientHostBufferStore : public ClientHostBufferStore { const std::shared_ptr stub_; const IfrtProxyVersion version_; const uint64_t session_id_; - std::atomic next_handle_ = 0; - // Implementation note: `lookup_work_queue_` may have closures that invoke - // user-defined code. Each `Lookup()` call is associated with a scheduled - // closure, and the closure is used to first perform synchronous reads of the - // streaming RPC, and then to do `promise.Set()` for the Future returned to - // the caller. - std::unique_ptr lookup_work_queue_; + // Implementation note: `work_queue_` may have closures that invoke + // user-defined code. Each `Store()` and `Lookup()` call is associated with a + // scheduled closure, and the closure is used to first perform synchronous + // RPC reads or writes, and then to do `promise.Set()` for the Future returned + // to the caller. + std::unique_ptr work_queue_; }; } // namespace proxy diff --git a/xla/python/ifrt_proxy/client/host_buffer.h b/xla/python/ifrt_proxy/client/host_buffer.h index cf9ccfa7afe41..ce4706569ee46 100644 --- a/xla/python/ifrt_proxy/client/host_buffer.h +++ b/xla/python/ifrt_proxy/client/host_buffer.h @@ -32,8 +32,6 @@ class ClientHostBufferStore { public: virtual ~ClientHostBufferStore() = default; - virtual uint64_t NextHandle() = 0; - // Stores the data associated with the given handle. Returns an error if the // handle already exists. virtual Future<> Store(uint64_t handle, absl::string_view data) = 0; diff --git a/xla/python/ifrt_proxy/client/mock_host_buffer.h b/xla/python/ifrt_proxy/client/mock_host_buffer.h index f947f350d5694..ea10be8a16b84 100644 --- a/xla/python/ifrt_proxy/client/mock_host_buffer.h +++ b/xla/python/ifrt_proxy/client/mock_host_buffer.h @@ -20,7 +20,6 @@ #include #include -#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "xla/python/ifrt/future.h" @@ -32,7 +31,6 @@ namespace proxy { class MockClientHostBufferStore final : public ClientHostBufferStore { public: - MOCK_METHOD(uint64_t, NextHandle, (), (override)); MOCK_METHOD(Future<>, Store, (uint64_t handle, absl::string_view data), (override)); MOCK_METHOD(Future<>, Store, (uint64_t handle, const absl::Cord& data), diff --git a/xla/python/ifrt_proxy/client/registry.cc b/xla/python/ifrt_proxy/client/registry.cc index 11680771b8b49..e50d267b94e92 100644 --- a/xla/python/ifrt_proxy/client/registry.cc +++ b/xla/python/ifrt_proxy/client/registry.cc @@ -22,6 +22,7 @@ #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -29,6 +30,7 @@ #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/global_flags.h" namespace xla { namespace ifrt { @@ -77,6 +79,9 @@ absl::StatusOr> CreateClient( const absl::string_view transport_name = proxy_server_address.substr(0, pos); const absl::string_view address = proxy_server_address.substr(pos + 3); + LOG(INFO) << "Attempting to create IFRT proxy client with transport name " + << transport_name << " to address '" << address + << "' and with global client flags " << *GetGlobalClientFlags(); FactoryFn factory; { diff --git a/xla/python/ifrt_proxy/client/rpc_helper.cc b/xla/python/ifrt_proxy/client/rpc_helper.cc index 19998ffd34619..849e82a5a3571 100644 --- a/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -15,6 +15,7 @@ #include "xla/python/ifrt_proxy/client/rpc_helper.h" #include +#include #include #include #include @@ -36,15 +37,13 @@ #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/prof_util.h" #include "xla/python/ifrt_proxy/common/test_utils.h" #include "xla/python/ifrt_proxy/common/types.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/env.h" -#include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/profiler/lib/traceme_encode.h" namespace xla { namespace ifrt { @@ -52,67 +51,8 @@ namespace proxy { namespace { -using ::tsl::profiler::XFlow; - constexpr absl::Duration kPeriodicFlushInterval = absl::Microseconds(50); -// XFlowHelper makes it easier to create trace spans with a flow between them. -// Typical usage: -// -// XFlowHelper flow("my_request"); -// ... -// -// auto response_handler = [flow](ResponseMsg msg) { -// flow.InstantActivity(); -// LOG(INFO) << "Received response: " << msg; -// } -// -// { -// auto request_span = flow.Span(); -// auto request_protobuf = CreateRequestProtobuf(); -// transport.Send(request_protobuf, response_handler); -// } -// -// -class XFlowHelper { - public: - explicit XFlowHelper(absl::string_view name) - : xflow_id_(tsl::random::New64() >> 8 /*XFlow IDs are 56 bits*/), - name_(name) {} - - typedef enum { kSend, kRecv, kRecvSend } Direction; - - template - tsl::profiler::TraceMe Span() const { - return tsl::profiler::TraceMe([xflow_id = xflow_id_, name = name_] { - return Encode(xflow_id, name); - }); - } - - template - void InstantActivity() const { - return tsl::profiler::TraceMe::InstantActivity( - [xflow_id = xflow_id_, name = name_] { - return Encode(xflow_id, name); - }); - } - - private: - template - static std::string Encode(uint64_t xflow_id, absl::string_view name) { - static constexpr absl::string_view flow_dir_str = - D == kSend ? "send" : (D == kRecv ? "recv" : "recv_send"); - const XFlow flow(xflow_id, D == kRecvSend ? XFlow::kFlowInOut - : (D == kRecv ? XFlow::kFlowIn - : XFlow::kFlowOut)); - return tsl::profiler::TraceMeEncode( - name, {{"dir", flow_dir_str}, {"flow", flow.ToStatValue()}}); - }; - - const uint64_t xflow_id_; - const absl::string_view name_; -}; - // Thread-safe data structure for holding batched operations. class BatchedOps { public: @@ -421,6 +361,12 @@ void RpcHelper::Disconnect() { batcher_->Finish(absl::CancelledError("Disconnected by client")); } +uint64_t RpcHelper::NextHandle() { + uint64_t result = next_handle_.fetch_add(1, std::memory_order_relaxed); + CHECK_LT(result, kServerGeneratedHandlesMinValue); + return result; +} + } // namespace proxy } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt_proxy/client/rpc_helper.h b/xla/python/ifrt_proxy/client/rpc_helper.h index fc88c22756502..38b61d83cbaa6 100644 --- a/xla/python/ifrt_proxy/client/rpc_helper.h +++ b/xla/python/ifrt_proxy/client/rpc_helper.h @@ -17,6 +17,7 @@ #ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ #define XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ +#include #include #include #include @@ -139,7 +140,12 @@ class RpcHelper { ResponseFuture LoadedHostCallbackReturn( std::unique_ptr req); - // Utility functions for common functions. + // Utility functions. + + // Generates a handle for new arrays, array data stored in HostBufferStore, + // etc. Guarantees that the generated handle will not conflict with those + // generated at the server side by IfrtBackend. + uint64_t NextHandle(); Future<> CheckFuture(uint64_t handle); @@ -149,6 +155,8 @@ class RpcHelper { const IfrtProxyVersion version_; std::shared_ptr host_buffer_store_; + std::atomic next_handle_ = 1; + absl::Mutex mu_; uint64_t next_op_id_ ABSL_GUARDED_BY(mu_) = 1; }; diff --git a/xla/python/ifrt_proxy/client/version.h b/xla/python/ifrt_proxy/client/version.h index fcdb1202c26e1..19f51201c2ce8 100644 --- a/xla/python/ifrt_proxy/client/version.h +++ b/xla/python/ifrt_proxy/client/version.h @@ -24,7 +24,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kClientMinVersion = 3; -inline constexpr int kClientMaxVersion = 9; +inline constexpr int kClientMaxVersion = 10; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) } // namespace proxy diff --git a/xla/python/ifrt_proxy/common/BUILD b/xla/python/ifrt_proxy/common/BUILD index ad9bb54e38d6d..ff7f6d12e44d8 100644 --- a/xla/python/ifrt_proxy/common/BUILD +++ b/xla/python/ifrt_proxy/common/BUILD @@ -202,6 +202,19 @@ cc_library( ], ) +cc_library( + name = "prof_util", + hdrs = ["prof_util.h"], + deps = [ + ":ifrt_service_proto_cc", + "//xla/tsl/profiler/utils:xplane_schema", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:random", + "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/profiler/lib:traceme_encode", + ], +) + # copybara:uncomment_begin # bzl_library( # name = "ifrt_proxy_bzl", diff --git a/xla/python/ifrt_proxy/common/VERSION.md b/xla/python/ifrt_proxy/common/VERSION.md index b1a3b3fb88f38..a2db37a1b06e3 100644 --- a/xla/python/ifrt_proxy/common/VERSION.md +++ b/xla/python/ifrt_proxy/common/VERSION.md @@ -53,3 +53,9 @@ * Added date: 2024-10-31. * Changes: * Added support for string Arrays (i.e., arrays with dtype `DType::kString`). + +## Version 10 + +* Added date: 2024-11-08. +* Changes: + * MakeArrayFromHostBuffer uses client-manufactured array handles and sends data asynchronously. diff --git a/xla/python/ifrt_proxy/common/array_util.cc b/xla/python/ifrt_proxy/common/array_util.cc index 5433fabbf4d42..6b3bb83863d49 100644 --- a/xla/python/ifrt_proxy/common/array_util.cc +++ b/xla/python/ifrt_proxy/common/array_util.cc @@ -14,6 +14,7 @@ #include "xla/python/ifrt_proxy/common/array_util.h" +#include #include #include @@ -154,13 +155,13 @@ void* ArrayMemRegion::zeroth_element() const { return mem_region_start_; } -absl::StatusOr SerializeStringHostBuffer( +absl::StatusOr> SerializeStringHostBuffer( absl::Span cords) { proto::StringArrayContents string_array_proto; for (const auto& c : cords) { string_array_proto.add_strings(std::string(c)); } - return string_array_proto.SerializeAsString(); + return std::make_unique(string_array_proto.SerializeAsString()); } absl::StatusOr> DeserializeStringHostBufferFromString( diff --git a/xla/python/ifrt_proxy/common/array_util.h b/xla/python/ifrt_proxy/common/array_util.h index da43fb31b8a50..90b171b144c87 100644 --- a/xla/python/ifrt_proxy/common/array_util.h +++ b/xla/python/ifrt_proxy/common/array_util.h @@ -79,7 +79,7 @@ class ArrayMemRegion { // Utilities for serializing and deserializing a host buffer of dtype // `DType::kString` (represented as arrays of absl::Cords). -absl::StatusOr SerializeStringHostBuffer( +absl::StatusOr> SerializeStringHostBuffer( absl::Span cords); absl::StatusOr> DeserializeStringHostBufferFromString( diff --git a/xla/python/ifrt_proxy/common/array_util_test.cc b/xla/python/ifrt_proxy/common/array_util_test.cc index f95d65ad4addf..86e1e8485c0e1 100644 --- a/xla/python/ifrt_proxy/common/array_util_test.cc +++ b/xla/python/ifrt_proxy/common/array_util_test.cc @@ -199,7 +199,7 @@ TEST(StringHostBufferTest, SerializeDeserializeWithString) { std::vector input = {absl::Cord("foo"), absl::Cord("bar")}; TF_ASSERT_OK_AND_ASSIGN(auto serialized, SerializeStringHostBuffer(input)); TF_ASSERT_OK_AND_ASSIGN(auto deserialized, - DeserializeStringHostBufferFromString(serialized)); + DeserializeStringHostBufferFromString(*serialized)); EXPECT_EQ(deserialized, input); } @@ -210,7 +210,7 @@ TEST(StringHostBufferTest, std::vector deserialized(input.size()); ASSERT_THAT(DeserializeFromCordIntoPreallocatedStringHostBuffer( - absl::Cord(serialized), deserialized.data()), + absl::Cord(*serialized), deserialized.data()), IsOk()); EXPECT_EQ(deserialized, input); diff --git a/xla/python/ifrt_proxy/common/ifrt_service.proto b/xla/python/ifrt_proxy/common/ifrt_service.proto index a0812a521f9e3..761240311bb59 100644 --- a/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -257,6 +257,9 @@ message MakeArrayFromHostBufferRequest { ShardingProto sharding = 3; fixed64 host_buffer_handle = 4; optional proto.ByteStrides byte_strides = 5; + // If array_handle is provided, the server will either respond with the same + // handle in `MakeArrayFromHostBufferResponse` or return an error. + optional fixed64 array_handle = 6; } message MakeArrayFromHostBufferResponse { fixed64 array_handle = 1; diff --git a/xla/python/ifrt_proxy/common/prof_util.h b/xla/python/ifrt_proxy/common/prof_util.h new file mode 100644 index 0000000000000..54eef195357cb --- /dev/null +++ b/xla/python/ifrt_proxy/common/prof_util.h @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_PROF_UTIL_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_PROF_UTIL_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "tsl/platform/random.h" +#include "tsl/profiler/lib/traceme.h" +#include "tsl/profiler/lib/traceme_encode.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// XFlowHelper makes it easier to create trace spans with a flow between them. +// Typical usage: +// +// XFlowHelper flow("my_request"); +// ... +// +// auto response_handler = [flow](ResponseMsg msg) { +// flow.InstantActivity(); +// LOG(INFO) << "Received response: " << msg; +// } +// +// { +// auto request_span = flow.Span(); +// auto request_protobuf = CreateRequestProtobuf(); +// transport.Send(request_protobuf, response_handler); +// } +// +class XFlowHelper { + public: + explicit XFlowHelper(absl::string_view name) + : xflow_id_(tsl::random::New64() >> 8 /*XFlow IDs are 56 bits*/), + name_(name) {} + + typedef enum { kSend, kRecv, kRecvSend } Direction; + + template + tsl::profiler::TraceMe Span() const { + return tsl::profiler::TraceMe([xflow_id = xflow_id_, name = name_] { + return Encode(xflow_id, name); + }); + } + + template + void InstantActivity() const { + return tsl::profiler::TraceMe::InstantActivity( + [xflow_id = xflow_id_, name = name_] { + return Encode(xflow_id, name); + }); + } + + private: + template + static std::string Encode(uint64_t xflow_id, absl::string_view name) { + using XFlow = ::tsl::profiler::XFlow; + switch (D) { + case kSend: + return tsl::profiler::TraceMeEncode( + name, {{"dir", "send"}, + {"flow", XFlow(xflow_id, XFlow::kFlowOut).ToStatValue()}}); + case kRecv: + return tsl::profiler::TraceMeEncode( + name, {{"dir", "recv"}, + {"flow", XFlow(xflow_id, XFlow::kFlowIn).ToStatValue()}}); + case kRecvSend: + return tsl::profiler::TraceMeEncode( + name, {{"dir", "recv_send"}, + {"flow", XFlow(xflow_id, XFlow::kFlowInOut).ToStatValue()}}); + } + }; + + const uint64_t xflow_id_; + const absl::string_view name_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_PROF_UTIL_H_ diff --git a/xla/python/ifrt_proxy/common/types.h b/xla/python/ifrt_proxy/common/types.h index 06f6771b54f9d..3fa85454992b1 100644 --- a/xla/python/ifrt_proxy/common/types.h +++ b/xla/python/ifrt_proxy/common/types.h @@ -57,6 +57,8 @@ absl::StatusOr ToVariantProto(const xla::PjRtValueType& value); std::vector FromByteStridesProto(const proto::ByteStrides& strides); proto::ByteStrides ToByteStridesProto(absl::Span strides); +constexpr uint64_t kServerGeneratedHandlesMinValue = 1ULL << 48; + } // namespace proxy } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt_proxy/integration_tests/BUILD b/xla/python/ifrt_proxy/integration_tests/BUILD index fe184323259bc..a0d91d534bb37 100644 --- a/xla/python/ifrt_proxy/integration_tests/BUILD +++ b/xla/python/ifrt_proxy/integration_tests/BUILD @@ -102,6 +102,7 @@ ifrt_proxy_cc_test( "//xla/python/pjrt_ifrt", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -110,7 +111,6 @@ ifrt_proxy_cc_test( "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:env", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc b/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc index 027811d574660..fcd2355c46a18 100644 --- a/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc +++ b/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc @@ -24,6 +24,7 @@ #include #include #include "absl/base/thread_annotations.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -50,11 +51,9 @@ #include "xla/python/ifrt_proxy/server/grpc_server.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/platform/env.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" namespace xla { namespace ifrt { @@ -80,15 +79,7 @@ class MockArrayTest : public testing::Test { CreateClient(absl::StrCat("grpc://", address))); } - struct ArrayPair { - // IFRT array exposed to the proxy's user. Not a mock. - tsl::RCReference proxy_client_array; - // IFRT array owned by the proxy server whose behavior should be - // reflected by proxy_client_array. Mock but delegated. - tsl::RCReference backend_array; - }; - - absl::StatusOr NewArray() { + absl::StatusOr> NewArray() { DType dtype(DType::kF32); Shape shape({2, 3}); auto data = std::make_unique>(6); @@ -105,21 +96,13 @@ class MockArrayTest : public testing::Test { Client::HostBufferSemantics::kImmutableOnlyDuringCall, /*on_done_with_host_buffer=*/nullptr)); - // When the above `MakeArrayFromHostBuffer` results in the server issuing a - // `MakeArrayFromHostBuffer()` to the underlying mock backend, the mock - // backend enqueues the returned mock array onto `mock_arrays_` (this code - // is in `CreateMockBackend()`). - absl::MutexLock l(&mu_); - CHECK_EQ(mock_arrays_.size(), 1); - auto mock = mock_arrays_.back(); - mock_arrays_.pop_back(); - return ArrayPair{client_arr, mock}; + return client_arr; } std::unique_ptr server_; std::unique_ptr client_; - private: + protected: absl::StatusOr> CreateMockBackend() { // TODO(b/292339723): Use reference backend as the delegate while mocking. xla::CpuClientOptions options; @@ -145,9 +128,26 @@ class MockArrayTest : public testing::Test { data, dtype, shape, byte_strides, sharding, semantics, on_done_with_host_buffer)); auto result = tsl::MakeRef(delegated); - - absl::MutexLock l(&mu_); - mock_arrays_.push_back(result); + ON_CALL(*result, GetReadyFuture) + .WillByDefault([this, delegated]() { + absl::MutexLock l(&mu_); + if (get_ready_hook_) { + absl::Status s = get_ready_hook_(); + if (!s.ok()) return Future<>(s); + } + return delegated->GetReadyFuture(); + }); + ON_CALL(*result, CopyToHostBuffer) + .WillByDefault([this, delegated](auto data, auto byte_strides, + auto semantics) { + absl::MutexLock l(&mu_); + if (copy_host_hook_) { + absl::Status s = copy_host_hook_(); + if (!s.ok()) return Future<>(s); + } + return delegated->CopyToHostBuffer(data, byte_strides, + semantics); + }); return result; }); @@ -165,20 +165,24 @@ class MockArrayTest : public testing::Test { } absl::Mutex mu_; - std::vector> mock_arrays_ ABSL_GUARDED_BY(mu_); + absl::AnyInvocable get_ready_hook_ ABSL_GUARDED_BY(mu_); + absl::AnyInvocable copy_host_hook_ ABSL_GUARDED_BY(mu_); }; TEST_F(MockArrayTest, ReadyFutureWaitsUntilReady) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + TF_ASSERT_OK_AND_ASSIGN(auto arr, NewArray()); absl::Notification wait_ready; - EXPECT_CALL(*arr.backend_array, GetReadyFuture).WillOnce([&] { - wait_ready.WaitForNotification(); - return arr.backend_array->delegated()->GetReadyFuture(); - }); + { + absl::MutexLock l(&mu_); + get_ready_hook_ = [&]() { + wait_ready.WaitForNotification(); + return absl::OkStatus(); + }; + } - auto ready = arr.proxy_client_array->GetReadyFuture(); + auto ready = arr->GetReadyFuture(); absl::SleepFor(kSomeTime); EXPECT_FALSE(ready.IsReady()); @@ -188,31 +192,34 @@ TEST_F(MockArrayTest, ReadyFutureWaitsUntilReady) { } TEST_F(MockArrayTest, ReadyFuturePropagatesError) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + TF_ASSERT_OK_AND_ASSIGN(auto arr, NewArray()); - EXPECT_CALL(*arr.backend_array, GetReadyFuture).WillOnce([&] { - return Future<>(absl::InternalError("testing")); - }); + absl::Notification wait_ready; - EXPECT_THAT(arr.proxy_client_array->GetReadyFuture().Await(), - StatusIs(kInternal)); + { + absl::MutexLock l(&mu_); + get_ready_hook_ = [&]() { return absl::InternalError("testing"); }; + } + + EXPECT_THAT(arr->GetReadyFuture().Await(), StatusIs(kInternal)); } TEST_F(MockArrayTest, CopyToHostFutureWaitsUntilCopied) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + TF_ASSERT_OK_AND_ASSIGN(auto arr, NewArray()); absl::Notification wait_ready; - EXPECT_CALL(*arr.backend_array, CopyToHostBuffer) - .WillOnce([&](auto data, auto byte_strides, auto semantics) { - wait_ready.WaitForNotification(); - return arr.backend_array->delegated()->CopyToHostBuffer( - data, byte_strides, semantics); - }); + { + absl::MutexLock l(&mu_); + copy_host_hook_ = [&]() { + wait_ready.WaitForNotification(); + return absl::OkStatus(); + }; + } char data[1000]; - auto copied = arr.proxy_client_array->CopyToHostBuffer( - data, /*byte_strides=*/std::nullopt, ArrayCopySemantics::kAlwaysCopy); + auto copied = arr->CopyToHostBuffer(data, /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy); absl::SleepFor(kSomeTime); EXPECT_FALSE(copied.IsReady()); @@ -222,17 +229,18 @@ TEST_F(MockArrayTest, CopyToHostFutureWaitsUntilCopied) { } TEST_F(MockArrayTest, CopyToHostFuturePropagatesError) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + TF_ASSERT_OK_AND_ASSIGN(auto arr, NewArray()); absl::Notification wait_ready; - EXPECT_CALL(*arr.backend_array, CopyToHostBuffer).WillOnce([&] { - return Future<>(absl::InternalError("testing")); - }); + { + absl::MutexLock l(&mu_); + copy_host_hook_ = [&]() { return absl::InternalError("testing"); }; + } char data[1000]; - auto copied = arr.proxy_client_array->CopyToHostBuffer( - data, /*byte_strides=*/std::nullopt, ArrayCopySemantics::kAlwaysCopy); + auto copied = arr->CopyToHostBuffer(data, /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy); EXPECT_THAT(copied.Await(), StatusIs(kInternal)); } diff --git a/xla/python/ifrt_proxy/server/BUILD b/xla/python/ifrt_proxy/server/BUILD index 56c75dec2a6fa..4909062d22b8a 100644 --- a/xla/python/ifrt_proxy/server/BUILD +++ b/xla/python/ifrt_proxy/server/BUILD @@ -148,6 +148,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@tsl//tsl/platform:env", @@ -235,6 +236,8 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@tsl//tsl/profiler/lib:traceme", ], ) @@ -290,8 +293,12 @@ ifrt_proxy_cc_test( srcs = ["host_buffer_test.cc"], deps = [ ":host_buffer", + "//xla/python/ifrt", "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:env", "@tsl//tsl/platform:status_matchers", ], ) diff --git a/xla/python/ifrt_proxy/server/host_buffer.cc b/xla/python/ifrt_proxy/server/host_buffer.cc index 4b9dd7391ec81..a1fe5dc276bb2 100644 --- a/xla/python/ifrt_proxy/server/host_buffer.cc +++ b/xla/python/ifrt_proxy/server/host_buffer.cc @@ -14,15 +14,19 @@ #include "xla/python/ifrt_proxy/server/host_buffer.h" +#include #include #include #include +#include "absl/base/thread_annotations.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace ifrt { @@ -30,6 +34,9 @@ namespace proxy { absl::Status HostBufferStore::Store(uint64_t handle, std::string data) { absl::MutexLock lock(&mu_); + if (shutdown_msg_.has_value()) { + return absl::CancelledError(*shutdown_msg_); + } const bool inserted = buffers_.insert({handle, std::make_shared(std::move(data))}) .second; @@ -41,8 +48,18 @@ absl::Status HostBufferStore::Store(uint64_t handle, std::string data) { } absl::StatusOr> HostBufferStore::Lookup( - uint64_t handle) { + uint64_t handle, absl::Duration timeout) { absl::MutexLock lock(&mu_); + auto cond = [&]() ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return shutdown_msg_.has_value() || buffers_.contains(handle); + }; + if (!cond()) { + tsl::profiler::TraceMe traceme("HostBufferStore::Lookup.Wait"); + mu_.AwaitWithTimeout(absl::Condition(&cond), timeout); + } + if (shutdown_msg_) { + return absl::CancelledError(shutdown_msg_.value()); + } const auto it = buffers_.find(handle); if (it == buffers_.end()) { return absl::NotFoundError( @@ -60,6 +77,14 @@ absl::Status HostBufferStore::Delete(uint64_t handle) { return absl::OkStatus(); } +void HostBufferStore::Shutdown(std::string reason) { + absl::MutexLock lock(&mu_); + if (!shutdown_msg_.has_value()) { + shutdown_msg_ = std::move(reason); + } + buffers_.clear(); +} + } // namespace proxy } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt_proxy/server/host_buffer.h b/xla/python/ifrt_proxy/server/host_buffer.h index 0d82898be9258..a11a00f7c7def 100644 --- a/xla/python/ifrt_proxy/server/host_buffer.h +++ b/xla/python/ifrt_proxy/server/host_buffer.h @@ -17,14 +17,18 @@ #ifndef XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ #define XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ +#include #include +#include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" namespace xla { namespace ifrt { @@ -42,17 +46,25 @@ class HostBufferStore { absl::Status Store(uint64_t handle, std::string data); // Retrieves the data associated with the handle. Returns an error if the - // handle does not exist. - absl::StatusOr> Lookup(uint64_t handle); + // handle does not exist within the given timeout or if `Shutdown()` is + // called. + absl::StatusOr> Lookup( + uint64_t handle, absl::Duration timeout = absl::ZeroDuration()); // Deletes the host buffer associated with the handle. Returns an error if the // handle does not exist. absl::Status Delete(uint64_t handle); + // Deletes all handles and permanently prevents addition of any new handles. + void Shutdown(std::string reason); + + ~HostBufferStore() { Shutdown("HostBufferStore is being destroyed"); } + private: absl::Mutex mu_; absl::flat_hash_map> buffers_ ABSL_GUARDED_BY(mu_); + std::optional shutdown_msg_ ABSL_GUARDED_BY(mu_); }; } // namespace proxy diff --git a/xla/python/ifrt_proxy/server/host_buffer_test.cc b/xla/python/ifrt_proxy/server/host_buffer_test.cc index 7adc31658dda3..98ea1097a052b 100644 --- a/xla/python/ifrt_proxy/server/host_buffer_test.cc +++ b/xla/python/ifrt_proxy/server/host_buffer_test.cc @@ -15,11 +15,17 @@ #include "xla/python/ifrt_proxy/server/host_buffer.h" #include +#include #include #include #include #include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "xla/python/ifrt/future.h" +#include "tsl/platform/env.h" #include "tsl/platform/status_matchers.h" namespace xla { @@ -27,6 +33,7 @@ namespace ifrt { namespace proxy { namespace { +using ::testing::Not; using ::testing::Pointee; using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; @@ -43,6 +50,65 @@ TEST(HostBufferStoreTest, ReadAfterWrite) { EXPECT_THAT(store.Lookup(kHandle), StatusIs(absl::StatusCode::kNotFound)); } +TEST(HostBufferStoreTest, WriteAfterReadStarted) { + HostBufferStore store; + const uint64_t kHandle = 1; + + auto lookup_promise = + Future>::CreatePromise(); + Future> lookup_fut(lookup_promise); + + absl::Notification closure_started; + tsl::Env::Default()->SchedClosure([&]() { + closure_started.Notify(); + lookup_promise.Set(store.Lookup(kHandle, /*timeout=*/absl::Seconds(10))); + }); + + closure_started.WaitForNotification(); + absl::SleepFor(absl::Seconds(1)); + + ASSERT_THAT(store.Store(kHandle, "foo"), IsOk()); + EXPECT_THAT(lookup_fut.Await(), IsOkAndHolds(Pointee(std::string("foo")))); +} + +TEST(HostBufferStoreTest, ShutdownAfterReadStarted) { + HostBufferStore store; + const uint64_t kHandle = 1; + + auto lookup_promise = + Future>::CreatePromise(); + Future> lookup_fut(lookup_promise); + + absl::Notification closure_started; + tsl::Env::Default()->SchedClosure([&]() { + closure_started.Notify(); + lookup_promise.Set(store.Lookup(kHandle, /*timeout=*/absl::Seconds(10))); + }); + + closure_started.WaitForNotification(); + absl::SleepFor(absl::Seconds(1)); + + store.Shutdown("test"); + EXPECT_THAT(lookup_fut.Await(), StatusIs(Not(absl::StatusCode::kOk))); +} + +TEST(HostBufferStoreTest, WriteAfterShutdown) { + HostBufferStore store; + const uint64_t kHandle = 1; + store.Shutdown("test"); + EXPECT_THAT(store.Store(kHandle, "foo"), + StatusIs(Not(absl::StatusCode::kOk))); +} + +TEST(HostBufferStoreTest, LookupAfterShutdown) { + HostBufferStore store; + const uint64_t kHandle = 1; + ASSERT_OK(store.Store(kHandle, "foo")); + store.Shutdown("test"); + EXPECT_THAT(store.Lookup(kHandle, /*timeout=*/absl::InfiniteDuration()), + StatusIs(Not(absl::StatusCode::kOk))); +} + TEST(HostBufferStoreTest, UnknownHandle) { HostBufferStore store; const uint64_t kHandle = 1; diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.cc b/xla/python/ifrt_proxy/server/ifrt_backend.cc index f7550e56b3dc7..a90a0066c9542 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -39,6 +40,7 @@ #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "xla/layout.h" @@ -102,6 +104,95 @@ MakeStringArrayFromHostBuffer( } // namespace +class IfrtBackend::InOrderRequestsProcessor { + struct Entry { + std::unique_ptr req; + Future::Promise promise; + }; + + public: + explicit InOrderRequestsProcessor(IfrtBackend* parent) + : parent_(parent), + thread_(tsl::Env::Default()->StartThread( + tsl::ThreadOptions(), "ifrt_backend_reqs_processor", + absl::bind_front(&InOrderRequestsProcessor::Loop, this))) {} + + void Shutdown(std::string reason) { + { + absl::MutexLock l(&mu_); + if (shutdown_msg_.has_value()) { + return; + } + shutdown_msg_ = reason; + } + + LOG(INFO) << "IfrtBackend::InOrderRequestsProcessor being destroyed, " + "waiting for currently processed request to finish."; + thread_.reset(); + std::deque should_cancel; + + { + absl::MutexLock l(&mu_); + entries_.swap(should_cancel); + } + + LOG(INFO) << "IfrtBackend::InOrderRequestsProcessor being destroyed, " + "cancelling remaining requests."; + for (auto& entry : should_cancel) { + entry.promise.Set(absl::CancelledError(reason)); + } + LOG(INFO) << "IfrtBackend::InOrderRequestsProcessor has been destroyed."; + } + + Future Push(std::unique_ptr request) { + auto promise = Future::CreatePromise(); + Future result(promise); + absl::MutexLock l(&mu_); + if (shutdown_msg_.has_value()) { + promise.Set(absl::InternalError(absl::StrCat( + "InOrderRequestsProcessor already stopped: ", *shutdown_msg_))); + return result; + } + Entry entry; + entry.req = std::move(request); + entry.promise = std::move(promise); + entries_.push_back(std::move(entry)); + return result; + } + + ~InOrderRequestsProcessor() { + Shutdown("InOrderRequestsProcessor is being destroyed"); + } + + private: + std::optional Pop() { + absl::MutexLock l(&mu_); + auto cond = [&]() ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return shutdown_msg_.has_value() || !entries_.empty(); + }; + mu_.Await(absl::Condition(&cond)); + if (shutdown_msg_.has_value()) return std::nullopt; + auto result = std::move(entries_.front()); + entries_.pop_front(); + return result; + } + + void Loop() { + while (auto entry = Pop()) { + parent_->ProcessInternal(std::move(entry->req)) + .OnReady( + [p = std::move(entry->promise)]( + absl::StatusOr r) mutable { p.Set(std::move(r)); }); + } + } + + absl::Mutex mu_; + std::deque entries_ ABSL_GUARDED_BY(mu_); + std::optional shutdown_msg_ ABSL_GUARDED_BY(mu_); + IfrtBackend* const parent_; + std::unique_ptr thread_; +}; + IfrtBackend::IfrtBackend(IfrtProxyVersion version, uint64_t session_id, std::shared_ptr ifrt_client, std::shared_ptr host_buffer_store) @@ -120,7 +211,9 @@ IfrtBackend::IfrtBackend(IfrtProxyVersion version, uint64_t session_id, }(), "IfrtBackend", // TODO(b/282757875): Consider making this configurable. - /*num_threads=*/32) {} + /*num_threads=*/32), + in_order_requests_processor_( + std::make_unique(this)) {} absl::StatusOr> IfrtBackend::Create( IfrtProxyVersion version, uint64_t session_id, @@ -142,6 +235,9 @@ absl::StatusOr> IfrtBackend::Create( } IfrtBackend::~IfrtBackend() { + in_order_requests_processor_->Shutdown("IFRT backend has shut down"); + host_buffer_store_->Shutdown("IFRT backend has shut down"); + // Cancel all in-flight host callback executions. { absl::MutexLock lock(&host_callback_queues_mutex_); @@ -171,6 +267,11 @@ IfrtBackend::~IfrtBackend() { Future IfrtBackend::Process( std::unique_ptr request) { + return in_order_requests_processor_->Push(std::move(request)); +} + +Future IfrtBackend::ProcessInternal( + std::unique_ptr request) { switch (request->request_case()) { case IfrtRequest::RequestCase::kInitRequest: return Future(HandleInit(std::move(request))); @@ -236,15 +337,21 @@ Future IfrtBackend::Process( } } +IfrtBackend::HandleGenerator::HandleGenerator() + : current_(kServerGeneratedHandlesMinValue) {} + uint64_t IfrtBackend::HandleGenerator::New() { absl::MutexLock lock(&mu_); - return current_++; + uint64_t result = current_++; + CHECK_GE(result, kServerGeneratedHandlesMinValue); + return result; } void IfrtBackend::HandleGenerator::BulkNew(absl::Span handles) { absl::MutexLock lock(&mu_); std::iota(handles.begin(), handles.end(), current_); current_ += handles.size(); + CHECK_GE(current_, kServerGeneratedHandlesMinValue); } Future IfrtBackend::AsyncExecute( @@ -424,11 +531,7 @@ Future IfrtBackend::HandleCheckValueReadyRequest( absl::StatusOr IfrtBackend::HandleMakeArrayFromHostBufferRequest( std::unique_ptr request) { - if (!request->has_make_array_from_host_buffer_request()) { - return absl::InternalError( - "MakeArrayFromHostBuffer got an IfrtRequest with no " - "MakeArrayFromHostBufferRequest in it."); - } + CHECK(request->has_make_array_from_host_buffer_request()); auto* make_array_request = request->mutable_make_array_from_host_buffer_request(); @@ -450,8 +553,10 @@ IfrtBackend::HandleMakeArrayFromHostBufferRequest( absl::Cleanup cleanup = [&] { CHECK_OK(host_buffer_store_->Delete(host_buffer_handle)); }; - TF_ASSIGN_OR_RETURN(std::shared_ptr host_buffer, - host_buffer_store_->Lookup(host_buffer_handle)); + TF_ASSIGN_OR_RETURN( + std::shared_ptr host_buffer, + host_buffer_store_->Lookup(host_buffer_handle, + /*timeout=*/absl::InfiniteDuration())); std::move(cleanup).Invoke(); tsl::RCReference array; @@ -476,10 +581,18 @@ IfrtBackend::HandleMakeArrayFromHostBufferRequest( // TODO(b/282757875): Consider merging the handle_generator with the // arrays_. - uint64_t handle = handle_generator_.New(); + uint64_t handle = make_array_request->has_array_handle() + ? make_array_request->array_handle() + : handle_generator_.New(); { absl::MutexLock lock(&arrays_mutex_); - arrays_.insert({handle, std::move(array)}); + const bool inserted = arrays_.insert({handle, std::move(array)}).second; + if (!inserted) { + CHECK(make_array_request->has_array_handle()) << handle; + return absl::InvalidArgumentError(absl::StrCat( + "IFRT proxy: MakeArrayFromHostBuffer with client-supplied handle ", + handle, " that already exists at the server.")); + } } std::unique_ptr response = @@ -628,7 +741,7 @@ IfrtBackend::HandleCopyToStringHostBufferRequest( TF_ASSIGN_OR_RETURN(auto serialized_string_host_buffer, SerializeStringHostBuffer(*host_buffer)); TF_RETURN_IF_ERROR(host_buffer_store_->Store( - host_buffer_handle, std::move(serialized_string_host_buffer))); + host_buffer_handle, std::move(*serialized_string_host_buffer))); std::unique_ptr response = NewIfrtResponse(op_id); response->mutable_copy_to_host_buffer_response(); @@ -1404,7 +1517,8 @@ IfrtBackend::HandleLoadedHostCallbackReturnRequest( if (ret.has_result_host_buffer_handle()) { TF_ASSIGN_OR_RETURN( std::shared_ptr buffer, - host_buffer_store_->Lookup(ret.result_host_buffer_handle())); + host_buffer_store_->Lookup(ret.result_host_buffer_handle(), + /*timeout=*/absl::InfiniteDuration())); absl::Cleanup cleanup = [&] { CHECK_OK(host_buffer_store_->Delete(ret.result_host_buffer_handle())); }; diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.h b/xla/python/ifrt_proxy/server/ifrt_backend.h index 0ed8f091b2b18..02f1d8cf15be5 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend.h +++ b/xla/python/ifrt_proxy/server/ifrt_backend.h @@ -78,10 +78,13 @@ class IfrtBackend final : public BackendInterface { Future Process(std::unique_ptr request) override; private: - // Generates unique handles for returning to the client. All object types - // currently use this single "handle space". + // Generates unique handles for returning to the client. Guaranteed to return + // handles that do not conflict with client-generated handles (via client-side + // RpcHelper). All object types currently use this single "handle space". class HandleGenerator { public: + HandleGenerator(); + uint64_t New(); // Bulk allocates a given number of handles and saves them into the provided @@ -90,7 +93,7 @@ class IfrtBackend final : public BackendInterface { private: absl::Mutex mu_; - uint64_t current_ ABSL_GUARDED_BY(mu_) = 1; + uint64_t current_ ABSL_GUARDED_BY(mu_); }; IfrtBackend(IfrtProxyVersion version, uint64_t session_id, @@ -105,6 +108,8 @@ class IfrtBackend final : public BackendInterface { std::function()> handle_fn, tsl::thread::ThreadPool* thread_pool = nullptr); + Future ProcessInternal(std::unique_ptr request); + ////////////////////////////////////////////////////////////////////// // Handlers for individual requests // @@ -214,6 +219,9 @@ class IfrtBackend final : public BackendInterface { // Use a separate thread pool for compilation as XLA compilation often // requires a bigger stack. tsl::thread::ThreadPool compile_thread_pool_; + + class InOrderRequestsProcessor; + std::unique_ptr in_order_requests_processor_; }; } // namespace proxy diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index e7304211faac5..5797873ace71d 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -662,12 +662,12 @@ TEST_P(IfrtBackendHandlerTest, MakeStringArrayFromHostBufferSuccess) { // Make a string host buffer. const std::vector input_strings = {absl::Cord("ab"), absl::Cord("cd")}; - TF_ASSERT_OK_AND_ASSIGN(const std::string serialized_string_buffer, + TF_ASSERT_OK_AND_ASSIGN(auto serialized_string_buffer, SerializeStringHostBuffer(input_strings)); const uint64_t kHostBufferHandle = 1234; ASSERT_THAT( - host_buffer_store_->Store(kHostBufferHandle, serialized_string_buffer), + host_buffer_store_->Store(kHostBufferHandle, *serialized_string_buffer), IsOk()); auto ifrt_request = NewIfrtRequest(NewOpId()); @@ -805,12 +805,12 @@ TEST_P(IfrtBackendHandlerTest, CopyToHostSuccessWithStringArray) { // Make a string host buffer. const std::vector input_strings = {absl::Cord("ab"), absl::Cord("cd")}; - TF_ASSERT_OK_AND_ASSIGN(const std::string serialized_string_buffer, + TF_ASSERT_OK_AND_ASSIGN(auto serialized_string_buffer, SerializeStringHostBuffer(input_strings)); const uint64_t kHostBufferHandle = 1234; ASSERT_THAT( - host_buffer_store_->Store(kHostBufferHandle, serialized_string_buffer), + host_buffer_store_->Store(kHostBufferHandle, *serialized_string_buffer), IsOk()); auto ifrt_request = NewIfrtRequest(NewOpId()); diff --git a/xla/python/ifrt_proxy/server/version.h b/xla/python/ifrt_proxy/server/version.h index 7952d34c01b82..1eb86595f5288 100644 --- a/xla/python/ifrt_proxy/server/version.h +++ b/xla/python/ifrt_proxy/server/version.h @@ -26,7 +26,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kServerMinVersion = 1; -inline constexpr int kServerMaxVersion = 9; +inline constexpr int kServerMaxVersion = 10; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) // Returns a version that both the client and the server support, or an error if