From 23f3961be010f6df2593c3f1810900a66249377a Mon Sep 17 00:00:00 2001
From: Iman Tabrizian
Date: Tue, 17 Sep 2024 15:34:31 +0000
Subject: [PATCH] Review comments

---
 src/ipc_message.cc     | 14 +++++---
 src/ipc_message.h      |  8 +++--
 src/pb_stub.cc         | 78 +++++++++++++++++++++++++----------------
 src/pb_stub.h          |  3 +-
 src/python_be.cc       | 21 +++++++-----
 src/python_be.h        |  9 +++--
 src/response_sender.cc |  1 -
 7 files changed, 78 insertions(+), 56 deletions(-)

diff --git a/src/ipc_message.cc b/src/ipc_message.cc
index 1b813214..2fa13ba3 100644
--- a/src/ipc_message.cc
+++ b/src/ipc_message.cc
@@ -57,13 +57,15 @@ IPCMessage::Create(
 }
 
 std::unique_ptr<IPCMessage>
-IPCMessage::Create(IPCMessageShm* ipc_message_shm,
-  bi::managed_external_buffer::handle_t& message_handle)
+IPCMessage::Create(
+    IPCMessageShm* ipc_message_shm,
+    bi::managed_external_buffer::handle_t& message_handle)
 {
-  return std::unique_ptr<IPCMessage>(new IPCMessage(ipc_message_shm, message_handle));
+  return std::unique_ptr<IPCMessage>(
+      new IPCMessage(ipc_message_shm, message_handle));
 }
 
-  AllocatedSharedMemory<IPCMessageShm>&
+AllocatedSharedMemory<IPCMessageShm>&
 IPCMessage::GetAllocatedSharedMemory()
 {
   return ipc_message_shm_;
@@ -146,7 +148,9 @@ IPCMessage::IPCMessage(
   ipc_message_handle_ = ipc_message_shm_.handle_;
 }
 
-IPCMessage::IPCMessage(IPCMessageShm* ipc_message_shm, bi::managed_external_buffer::handle_t& handle)
+IPCMessage::IPCMessage(
+    IPCMessageShm* ipc_message_shm,
+    bi::managed_external_buffer::handle_t& handle)
 {
   ipc_message_handle_ = handle;
   ipc_message_shm_ptr_ = ipc_message_shm;
diff --git a/src/ipc_message.h b/src/ipc_message.h
index c7d0ae9d..c3d1472e 100644
--- a/src/ipc_message.h
+++ b/src/ipc_message.h
@@ -98,8 +98,8 @@ class IPCMessage {
       const std::unique_ptr<SharedMemoryManager>& shm_pool,
       bool inline_response);
 
-  static std::unique_ptr<IPCMessage>
-  Create(IPCMessageShm* ipc_message_shm,
+  static std::unique_ptr<IPCMessage> Create(
+      IPCMessageShm* ipc_message_shm,
       bi::managed_external_buffer::handle_t& message_handle);
   static std::unique_ptr<IPCMessage> LoadFromSharedMemory(
       std::unique_ptr<SharedMemoryManager>& shm_pool,
@@ -135,7 +135,9 @@ class IPCMessage {
       AllocatedSharedMemory<bi::interprocess_mutex>& response_mutex_shm,
       AllocatedSharedMemory<bi::interprocess_condition>& response_cond_shm);
 
-  IPCMessage(IPCMessageShm* ipc_message_shm, bi::managed_external_buffer::handle_t& handle);
+  IPCMessage(
+      IPCMessageShm* ipc_message_shm,
+      bi::managed_external_buffer::handle_t& handle);
 };
 
 }}};  // namespace triton::backend::python
diff --git a/src/pb_stub.cc b/src/pb_stub.cc
index e6c93214..4b7bffc1 100644
--- a/src/pb_stub.cc
+++ b/src/pb_stub.cc
@@ -654,7 +654,6 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
   py::list py_request_list =
       LoadRequestsFromSharedMemory(request_batch_shm_ptr);
   std::unique_ptr<IPCMessage> execute_response;
-  // IPCMessage::Create(shm_pool_, false /* Inline response */);
 
   std::optional<AllocatedSharedMemory<char>> response_batch;
   bool has_exception = false;
@@ -675,8 +674,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
     {
       NVTX_RANGE(nvtx_, "PyExecute " + name_);
 
-      execute_return =
-          model_instance_.attr("execute")(py_request_list);
+      execute_return = model_instance_.attr("execute")(py_request_list);
 
       bool is_coroutine = py::module::import("asyncio")
                               .attr("iscoroutine")(execute_return)
@@ -688,10 +686,12 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
       } else {
         py::object coroutine_return =
             RunCoroutine(execute_return, false /* in_background */);
-        ProcessReturnedResponses(py_request_list, coroutine_return, response_batch);
+        ProcessReturnedResponses(
+            py_request_list, coroutine_return, response_batch);
       }
     } else {
-      ProcessReturnedResponses(py_request_list, execute_return, response_batch);
+      ProcessReturnedResponses(
+          py_request_list, execute_return, response_batch);
     }
   }
 }
@@ -712,11 +712,14 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
         error_string;
     LOG_ERROR << err_message.c_str();
     if (!response_batch) {
-      response_batch = shm_pool_->Construct<char>(sizeof(ResponseBatch) + sizeof(IPCMessageShm));
-    }
-    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
+      response_batch = shm_pool_->Construct<char>(
+          sizeof(ResponseBatch) + sizeof(IPCMessageShm));
+    }
+    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+        response_batch.value().data_.get() + sizeof(IPCMessageShm));
 
-    response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get());
+    response_batch_shm_ptr =
+        reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get());
     response_batch_shm_ptr->has_error = true;
     error_string_shm = PbString::Create(shm_pool_, err_message);
     response_batch_shm_ptr->error = error_string_shm->ShmHandle();
@@ -732,14 +735,19 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
   }
 
   if (!response_batch) {
-    response_batch = shm_pool_->Construct<char>(sizeof(ResponseBatch) + sizeof(IPCMessageShm));
-    ResponseBatch* response_batch_shm_ptr =reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
-    response_batch_shm_ptr->batch_size = 0;
-  }
-  ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
+    response_batch = shm_pool_->Construct<char>(
+        sizeof(ResponseBatch) + sizeof(IPCMessageShm));
+    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+        response_batch.value().data_.get() + sizeof(IPCMessageShm));
+    response_batch_shm_ptr->batch_size = 0;
+  }
+  ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+      response_batch.value().data_.get() + sizeof(IPCMessageShm));
   response_batch_shm_ptr->has_error = false;
   response_batch_shm_ptr->is_error_set = false;
-  execute_response = IPCMessage::Create(reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()), response_batch.value().handle_);
+  execute_response = IPCMessage::Create(
+      reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()),
+      response_batch.value().handle_);
   execute_response->Args() = response_batch.value().handle_;
   execute_response->InlineResponse() = false;
   execute_response->Command() = PYTHONSTUB_ExecuteResponse;
@@ -761,7 +769,8 @@ Stub::ProcessResponse(InferResponse* response)
 
 void
 Stub::ProcessReturnedResponses(
-    py::list py_requests, py::object py_responses_obj, std::optional<AllocatedSharedMemory<char>>& response_batch)
+    py::list py_requests, py::object py_responses_obj,
+    std::optional<AllocatedSharedMemory<char>>& response_batch)
 {
   // Return if there is nothing to process.
   if (py::isinstance<py::none>(py_responses_obj)) {
@@ -812,29 +821,34 @@ Stub::ProcessReturnedResponses(
       std::shared_ptr<InferResponse> response =
           py_responses[i].cast<std::shared_ptr<InferResponse>>();
-      request->GetResponseSender()->UpdateStateAndCounters(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
+      request->GetResponseSender()->UpdateStateAndCounters(
+          response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
     }
   }
-  response_batch = std::move(shm_pool_->Construct<char>(sizeof(IPCMessageShm) +
+  // Return all the created responses using response_batch. The reason
+  // that both of the paths are available is that sending the responses
+  // using response_batch is faster than using `response_sender`.
+  response_batch = std::move(shm_pool_->Construct<char>(
+      sizeof(IPCMessageShm) +
       requests_size * sizeof(bi::managed_external_buffer::handle_t) +
       sizeof(ResponseBatch)));
-  ResponseBatch* response_batch_shm_ptr =
-      reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
+  ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+      response_batch.value().data_.get() + sizeof(IPCMessageShm));
   bi::managed_external_buffer::handle_t* responses_shm_handle =
       reinterpret_cast<bi::managed_external_buffer::handle_t*>(
-          response_batch.value().data_.get() + sizeof(ResponseBatch) + sizeof(IPCMessageShm));
-
-  for (size_t i = 0; i < responses_size; i++) {
-    // Check the return type of execute function.
-    InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
-    InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
-    infer_response->PruneOutputTensors(
-        infer_request->RequestedOutputNames());
-    ProcessResponse(infer_response);
-    responses_shm_handle[i] = infer_response->ShmHandle();
-  }
-  response_batch_shm_ptr->batch_size = requests_size;
+          response_batch.value().data_.get() + sizeof(ResponseBatch) +
+          sizeof(IPCMessageShm));
+
+  for (size_t i = 0; i < responses_size; i++) {
+    // Check the return type of execute function.
+    InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
+    InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
+    infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());
+    ProcessResponse(infer_response);
+    responses_shm_handle[i] = infer_response->ShmHandle();
+  }
+  response_batch_shm_ptr->batch_size = requests_size;
 }
 
 py::object
diff --git a/src/pb_stub.h b/src/pb_stub.h
index 85a2783a..7d76ec9a 100644
--- a/src/pb_stub.h
+++ b/src/pb_stub.h
@@ -254,7 +254,8 @@ class Stub {
   void ProcessRequests(RequestBatch* request_batch_shm_ptr);
 
   void ProcessReturnedResponses(
-      py::list py_requests, py::object py_responses_obj, std::optional<AllocatedSharedMemory<char>>& response_batch);
+      py::list py_requests, py::object py_responses_obj,
+      std::optional<AllocatedSharedMemory<char>>& response_batch);
 
   void ProcessResponse(InferResponse* response);
 
diff --git a/src/python_be.cc b/src/python_be.cc
index 1c6c6505..a1114efe 100644
--- a/src/python_be.cc
+++ b/src/python_be.cc
@@ -1023,7 +1023,7 @@ ModelInstanceState::SendMessageAndReceiveResponse(
     std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
     TRITONBACKEND_Request** requests, const uint32_t request_count)
 {
-  SendMessageToStub(message); 
+  SendMessageToStub(message);
 
   bi::managed_external_buffer::handle_t response_message;
   auto error = Stub()->ReceiveMessageFromStub(response_message);
@@ -1224,7 +1224,8 @@ ModelInstanceState::ResponseSendDecoupled(
   if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
     std::unique_ptr<
         TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
-        lresponse_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
+        lresponse_factory(
+            reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
   }
 }
 
@@ -1280,12 +1281,15 @@ ModelInstanceState::ProcessRequests(
     Stub()->StubMessageQueue()->Push(ipc_message->ShmHandle());
     bi::managed_external_buffer::handle_t response_message;
     Stub()->ReceiveMessageFromStub(response_message);
-    response = IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message);
+    response =
+        IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message);
   }
-  char* ipc_message_shm = reinterpret_cast<char*>(response->GetAllocatedSharedMemory().data_.get());;
+  char* ipc_message_shm =
+      reinterpret_cast<char*>(response->GetAllocatedSharedMemory().data_.get());
+  ;
   ResponseBatch* response_batch_shm_ptr =
       reinterpret_cast<ResponseBatch*>(ipc_message_shm + sizeof(IPCMessageShm));
-  
+
   uint64_t compute_end_ns = 0;
   SET_TIMESTAMP(compute_end_ns);
   reporter.SetComputeEndNs(compute_end_ns);
@@ -1304,10 +1308,10 @@
   }
 
   if (response_batch_shm_ptr->batch_size > 0) {
-    std::shared_ptr<std::vector<TRITONBACKEND_Response*>> responses(
-        new std::vector<TRITONBACKEND_Response*>());
+    std::shared_ptr<std::vector<TRITONBACKEND_Response*>> responses(
+        new std::vector<TRITONBACKEND_Response*>());
     responses->reserve(request_count);
-    for (size_t i = 0; i < request_count; i++) {
+    for (size_t i = 0; i < request_count; i++) {
       TRITONBACKEND_Response* response;
       auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
       if (err == nullptr) {
@@ -1324,7 +1328,6 @@ ModelInstanceState::ProcessRequests(
 
   // If the output provided by the model is in GPU, we will pass the list of
   // buffers provided by Triton to the stub process.
-  // bool has_gpu_output = false;
   std::vector<bool> requires_deferred_callback;
 
   bool has_gpu_output = false;
diff --git a/src/python_be.h b/src/python_be.h
index 4608298e..34871ea5 100644
--- a/src/python_be.h
+++ b/src/python_be.h
@@ -365,12 +365,11 @@ class ModelInstanceState : public BackendModelInstance {
       TRITONBACKEND_Request** requests, const uint32_t request_count);
 
   void RespondErrorToAllRequests(
-      const char* message,
-      std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
-      TRITONBACKEND_Request** requests, const uint32_t request_count);
+      const char* message,
+      std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
+      TRITONBACKEND_Request** requests, const uint32_t request_count);
 
-  void SendMessageToStub(
-      bi::managed_external_buffer::handle_t message);
+  void SendMessageToStub(bi::managed_external_buffer::handle_t message);
 
   // Model instance stub
   std::unique_ptr<StubLauncher>& Stub() { return model_instance_stub_; }
diff --git a/src/response_sender.cc b/src/response_sender.cc
index 043ef41d..7df90ec2 100644
--- a/src/response_sender.cc
+++ b/src/response_sender.cc
@@ -250,7 +250,6 @@ ResponseSender::Send(
           "An error occurred while sending a response.");
     }
   }
-
 }
 
 bool
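
Editor's note: the comment added in Stub::ProcessReturnedResponses explains why the stub packs all completed responses into a single response_batch allocation instead of pushing each one through `response_sender`: the parent process can then pick up the batch header and every response handle in one shared-memory read after a single queue push. The sketch below illustrates only the layout arithmetic that the patch uses. It is a standalone illustration, not code from this patch: IPCMessageShm, ResponseBatch, and handle_t here are simplified stand-ins for the real Triton Python backend types, and the batch size and handle values are hypothetical.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified stand-ins for the backend's shared-memory structs.
using handle_t = uint64_t;  // stands in for bi::managed_external_buffer::handle_t

struct IPCMessageShm {  // hypothetical fields; the real struct lives in ipc_message.h
  int32_t command;
  handle_t args;
};

struct ResponseBatch {  // hypothetical fields; the real struct is larger
  uint32_t batch_size;
  bool has_error;
  bool is_error_set;
};

int main()
{
  const size_t requests_size = 4;  // hypothetical batch of four requests

  // One allocation holds everything, mirroring the Construct<char>(...) call:
  //   [IPCMessageShm][ResponseBatch][handle_t x requests_size]
  std::vector<char> shm(
      sizeof(IPCMessageShm) + sizeof(ResponseBatch) +
      requests_size * sizeof(handle_t));
  char* data = shm.data();

  // Same pointer arithmetic as the patch: the batch header sits right after
  // the IPC message header, and the response handles follow the batch header.
  ResponseBatch* batch =
      reinterpret_cast<ResponseBatch*>(data + sizeof(IPCMessageShm));
  handle_t* handles = reinterpret_cast<handle_t*>(
      data + sizeof(IPCMessageShm) + sizeof(ResponseBatch));

  batch->batch_size = requests_size;
  batch->has_error = false;
  batch->is_error_set = false;
  for (size_t i = 0; i < requests_size; i++) {
    handles[i] = 1000 + i;  // placeholder for each response's shm handle
  }

  std::printf(
      "ipc header: %zu bytes, batch header: %zu bytes, handles at offset %zu\n",
      sizeof(IPCMessageShm), sizeof(ResponseBatch),
      sizeof(IPCMessageShm) + sizeof(ResponseBatch));
  return 0;
}

The design idea, as described by the patch's own comment, is that collecting handles this way replaces one IPC round trip per response with a single batched hand-off, while `response_sender` remains available for decoupled models.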