Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabrizian committed Sep 17, 2024
1 parent 3e9dcc5 commit 23f3961
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 56 deletions.
14 changes: 9 additions & 5 deletions src/ipc_message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ IPCMessage::Create(
}

std::unique_ptr<IPCMessage>
IPCMessage::Create(IPCMessageShm* ipc_message_shm,
bi::managed_external_buffer::handle_t& message_handle)
IPCMessage::Create(
IPCMessageShm* ipc_message_shm,
bi::managed_external_buffer::handle_t& message_handle)
{
return std::unique_ptr<IPCMessage>(new IPCMessage(ipc_message_shm, message_handle));
return std::unique_ptr<IPCMessage>(
new IPCMessage(ipc_message_shm, message_handle));
}

AllocatedSharedMemory<IPCMessageShm>&
AllocatedSharedMemory<IPCMessageShm>&
IPCMessage::GetAllocatedSharedMemory()
{
return ipc_message_shm_;
Expand Down Expand Up @@ -146,7 +148,9 @@ IPCMessage::IPCMessage(
ipc_message_handle_ = ipc_message_shm_.handle_;
}

IPCMessage::IPCMessage(IPCMessageShm* ipc_message_shm, bi::managed_external_buffer::handle_t& handle)
IPCMessage::IPCMessage(
IPCMessageShm* ipc_message_shm,
bi::managed_external_buffer::handle_t& handle)
{
ipc_message_handle_ = handle;
ipc_message_shm_ptr_ = ipc_message_shm;
Expand Down
8 changes: 5 additions & 3 deletions src/ipc_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ class IPCMessage {
const std::unique_ptr<SharedMemoryManager>& shm_pool,
bool inline_response);

static std::unique_ptr<IPCMessage>
Create(IPCMessageShm* ipc_message_shm,
static std::unique_ptr<IPCMessage> Create(
IPCMessageShm* ipc_message_shm,
bi::managed_external_buffer::handle_t& message_handle);
static std::unique_ptr<IPCMessage> LoadFromSharedMemory(
std::unique_ptr<SharedMemoryManager>& shm_pool,
Expand Down Expand Up @@ -135,7 +135,9 @@ class IPCMessage {
AllocatedSharedMemory<bi::interprocess_mutex>& response_mutex_shm,
AllocatedSharedMemory<bi::interprocess_condition>& response_cond_shm);

IPCMessage(IPCMessageShm* ipc_message_shm, bi::managed_external_buffer::handle_t& handle);
IPCMessage(
IPCMessageShm* ipc_message_shm,
bi::managed_external_buffer::handle_t& handle);
};

}}}; // namespace triton::backend::python
78 changes: 46 additions & 32 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,6 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
py::list py_request_list =
LoadRequestsFromSharedMemory(request_batch_shm_ptr);
std::unique_ptr<IPCMessage> execute_response;
// IPCMessage::Create(shm_pool_, false /* Inline response */);

std::optional<AllocatedSharedMemory<char>> response_batch;
bool has_exception = false;
Expand All @@ -675,8 +674,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
{
NVTX_RANGE(nvtx_, "PyExecute " + name_);

execute_return =
model_instance_.attr("execute")(py_request_list);
execute_return = model_instance_.attr("execute")(py_request_list);

bool is_coroutine = py::module::import("asyncio")
.attr("iscoroutine")(execute_return)
Expand All @@ -688,10 +686,12 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
} else {
py::object coroutine_return =
RunCoroutine(execute_return, false /* in_background */);
ProcessReturnedResponses(py_request_list, coroutine_return, response_batch);
ProcessReturnedResponses(
py_request_list, coroutine_return, response_batch);
}
} else {
ProcessReturnedResponses(py_request_list, execute_return, response_batch);
ProcessReturnedResponses(
py_request_list, execute_return, response_batch);
}
}
}
Expand All @@ -712,11 +712,14 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
error_string;
LOG_ERROR << err_message.c_str();
if (!response_batch) {
response_batch = shm_pool_->Construct<char>(sizeof(ResponseBatch) + sizeof(IPCMessageShm));
}
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch = shm_pool_->Construct<char>(
sizeof(ResponseBatch) + sizeof(IPCMessageShm));
}
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));

response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get());
response_batch_shm_ptr =
reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get());
response_batch_shm_ptr->has_error = true;
error_string_shm = PbString::Create(shm_pool_, err_message);
response_batch_shm_ptr->error = error_string_shm->ShmHandle();
Expand All @@ -732,14 +735,19 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
}

if (!response_batch) {
response_batch = shm_pool_->Construct<char>(sizeof(ResponseBatch) + sizeof(IPCMessageShm));
ResponseBatch* response_batch_shm_ptr =reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->batch_size = 0;
}
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch = shm_pool_->Construct<char>(
sizeof(ResponseBatch) + sizeof(IPCMessageShm));
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->batch_size = 0;
}
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->has_error = false;
response_batch_shm_ptr->is_error_set = false;
execute_response = IPCMessage::Create(reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()), response_batch.value().handle_);
execute_response = IPCMessage::Create(
reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()),
response_batch.value().handle_);
execute_response->Args() = response_batch.value().handle_;
execute_response->InlineResponse() = false;
execute_response->Command() = PYTHONSTUB_ExecuteResponse;
Expand All @@ -761,7 +769,8 @@ Stub::ProcessResponse(InferResponse* response)

void
Stub::ProcessReturnedResponses(
py::list py_requests, py::object py_responses_obj, std::optional<AllocatedSharedMemory<char>>& response_batch)
py::list py_requests, py::object py_responses_obj,
std::optional<AllocatedSharedMemory<char>>& response_batch)
{
// Return if there is nothing to process.
if (py::isinstance<py::none>(py_responses_obj)) {
Expand Down Expand Up @@ -812,29 +821,34 @@ Stub::ProcessReturnedResponses(

std::shared_ptr<InferResponse> response =
py_responses[i].cast<std::shared_ptr<InferResponse>>();
request->GetResponseSender()->UpdateStateAndCounters(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
request->GetResponseSender()->UpdateStateAndCounters(
response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
}
}
response_batch = std::move(shm_pool_->Construct<char>(sizeof(IPCMessageShm) +
// Return all the created responses using response_batch. The reason
// that both of the paths are available is that sending the responses
// using response_batch is faster than using `response_sender`.
response_batch = std::move(shm_pool_->Construct<char>(
sizeof(IPCMessageShm) +
requests_size * sizeof(bi::managed_external_buffer::handle_t) +
sizeof(ResponseBatch)));
ResponseBatch* response_batch_shm_ptr =
reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));

bi::managed_external_buffer::handle_t* responses_shm_handle =
reinterpret_cast<bi::managed_external_buffer::handle_t*>(
response_batch.value().data_.get() + sizeof(ResponseBatch) + sizeof(IPCMessageShm));

for (size_t i = 0; i < responses_size; i++) {
// Check the return type of execute function.
InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
infer_response->PruneOutputTensors(
infer_request->RequestedOutputNames());
ProcessResponse(infer_response);
responses_shm_handle[i] = infer_response->ShmHandle();
}
response_batch_shm_ptr->batch_size = requests_size;
response_batch.value().data_.get() + sizeof(ResponseBatch) +
sizeof(IPCMessageShm));

for (size_t i = 0; i < responses_size; i++) {
// Check the return type of execute function.
InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());
ProcessResponse(infer_response);
responses_shm_handle[i] = infer_response->ShmHandle();
}
response_batch_shm_ptr->batch_size = requests_size;
}

py::object
Expand Down
3 changes: 2 additions & 1 deletion src/pb_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ class Stub {
void ProcessRequests(RequestBatch* request_batch_shm_ptr);

void ProcessReturnedResponses(
py::list py_requests, py::object py_responses_obj, std::optional<AllocatedSharedMemory<char>>& response_batch);
py::list py_requests, py::object py_responses_obj,
std::optional<AllocatedSharedMemory<char>>& response_batch);

void ProcessResponse(InferResponse* response);

Expand Down
21 changes: 12 additions & 9 deletions src/python_be.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ ModelInstanceState::SendMessageAndReceiveResponse(
std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
TRITONBACKEND_Request** requests, const uint32_t request_count)
{
SendMessageToStub(message);
SendMessageToStub(message);

bi::managed_external_buffer::handle_t response_message;
auto error = Stub()->ReceiveMessageFromStub(response_message);
Expand Down Expand Up @@ -1224,7 +1224,8 @@ ModelInstanceState::ResponseSendDecoupled(
if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
std::unique_ptr<
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
lresponse_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
lresponse_factory(
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
}
}

Expand Down Expand Up @@ -1280,12 +1281,15 @@ ModelInstanceState::ProcessRequests(
Stub()->StubMessageQueue()->Push(ipc_message->ShmHandle());
bi::managed_external_buffer::handle_t response_message;
Stub()->ReceiveMessageFromStub(response_message);
response = IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message);
response =
IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message);
}
char* ipc_message_shm = reinterpret_cast<char*>(response->GetAllocatedSharedMemory().data_.get());;
char* ipc_message_shm =
reinterpret_cast<char*>(response->GetAllocatedSharedMemory().data_.get());
;
ResponseBatch* response_batch_shm_ptr =
reinterpret_cast<ResponseBatch*>(ipc_message_shm + sizeof(IPCMessageShm));

uint64_t compute_end_ns = 0;
SET_TIMESTAMP(compute_end_ns);
reporter.SetComputeEndNs(compute_end_ns);
Expand All @@ -1304,10 +1308,10 @@ ModelInstanceState::ProcessRequests(
}

if (response_batch_shm_ptr->batch_size > 0) {
std::shared_ptr<std::vector<TRITONBACKEND_Response*>> responses(
new std::vector<TRITONBACKEND_Response*>());
std::shared_ptr<std::vector<TRITONBACKEND_Response*>> responses(
new std::vector<TRITONBACKEND_Response*>());
responses->reserve(request_count);
for (size_t i = 0; i < request_count; i++) {
for (size_t i = 0; i < request_count; i++) {
TRITONBACKEND_Response* response;
auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
if (err == nullptr) {
Expand All @@ -1324,7 +1328,6 @@ ModelInstanceState::ProcessRequests(

// If the output provided by the model is in GPU, we will pass the list of
// buffers provided by Triton to the stub process.
// bool has_gpu_output = false;
std::vector<bool> requires_deferred_callback;

bool has_gpu_output = false;
Expand Down
9 changes: 4 additions & 5 deletions src/python_be.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,11 @@ class ModelInstanceState : public BackendModelInstance {
TRITONBACKEND_Request** requests, const uint32_t request_count);

void RespondErrorToAllRequests(
const char* message,
std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
TRITONBACKEND_Request** requests, const uint32_t request_count);
const char* message,
std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
TRITONBACKEND_Request** requests, const uint32_t request_count);

void SendMessageToStub(
bi::managed_external_buffer::handle_t message);
void SendMessageToStub(bi::managed_external_buffer::handle_t message);

// Model instance stub
std::unique_ptr<StubLauncher>& Stub() { return model_instance_stub_; }
Expand Down
1 change: 0 additions & 1 deletion src/response_sender.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ ResponseSender::Send(
"An error occurred while sending a response.");
}
}

}

bool
Expand Down

0 comments on commit 23f3961

Please sign in to comment.