diff --git a/include/triton/backend/backend_common.h b/include/triton/backend/backend_common.h index e8ba5fa..bbc80dd 100644 --- a/include/triton/backend/backend_common.h +++ b/include/triton/backend/backend_common.h @@ -671,4 +671,21 @@ TRITONSERVER_Error* BufferAsTypedString( /// \return a formatted string for logging the request ID. std::string GetRequestId(TRITONBACKEND_Request* request); +/// Validate the contiguous string buffer with correct format +/// ... and parse string +/// elements into list of pairs of memory address and length. +/// Note the returned list of pairs points to valid memory as long +/// as memory pointed by buffer remains allocated. +/// +/// \param buffer The pointer to the contiguous string buffer. +/// \param buffer_byte_size The size of the buffer in bytes. +/// \param expected_element_cnt The number of expected string elements. +/// \param input_name The name of the input buffer. +/// \param str_list Returns pairs of address and length of parsed strings. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ValidateStringBuffer( + const char* buffer, size_t buffer_byte_size, + const size_t expected_element_cnt, const char* input_name, + std::vector>* str_list); + }} // namespace triton::backend diff --git a/src/backend_common.cc b/src/backend_common.cc index 8c8821d..fb705aa 100644 --- a/src/backend_common.cc +++ b/src/backend_common.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -1372,4 +1372,64 @@ GetRequestId(TRITONBACKEND_Request* request) return std::string("[request id: ") + request_id + "] "; } +TRITONSERVER_Error* +ValidateStringBuffer( + const char* buffer, size_t buffer_byte_size, + const size_t expected_element_cnt, const char* input_name, + std::vector>* str_list) +{ + size_t element_idx = 0; + size_t remaining_bytes = buffer_byte_size; + + // Each string in 'buffer' is a 4-byte length followed by the string itself + // with no null-terminator. + while (remaining_bytes >= sizeof(uint32_t)) { + // Do not modify this line. str_list->size() must not exceed + // expected_element_cnt. + if (element_idx >= expected_element_cnt) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "unexpected number of string elements " + + std::to_string(element_idx + 1) + " for inference input '" + + input_name + "', expecting " + + std::to_string(expected_element_cnt)) + .c_str()); + } + + const uint32_t len = *(reinterpret_cast(buffer)); + remaining_bytes -= sizeof(uint32_t); + buffer += sizeof(uint32_t); + + if (remaining_bytes < len) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "incomplete string data for inference input '" + + std::string(input_name) + "', expecting string of length " + + std::to_string(len) + " but only " + + std::to_string(remaining_bytes) + " bytes available") + .c_str()); + } + + if (str_list) { + str_list->push_back({buffer, len}); + } + buffer += len; + remaining_bytes -= len; + element_idx++; + } + + if (element_idx != expected_element_cnt) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + "expected " + std::to_string(expected_element_cnt) + + " strings for inference input '" + input_name + "', got " + + std::to_string(element_idx)) + .c_str()); + } + return nullptr; +} + }} // namespace triton::backend