Skip to content

Commit

Permalink
Refactor string input checks
Browse files Browse the repository at this point in the history
  • Loading branch information
yinggeh committed Jul 18, 2024
1 parent c852a5e commit ea51c02
Showing 1 changed file with 9 additions and 51 deletions.
60 changes: 9 additions & 51 deletions src/libtorch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1937,64 +1937,22 @@ SetStringInputTensor(
}
#endif // TRITON_ENABLE_GPU

// Parse content and assign to 'tensor'. Each string in 'content'
// is a 4-byte length followed by the string itself with no
// null-terminator.
while (content_byte_size >= sizeof(uint32_t)) {
if (element_idx >= request_element_cnt) {
RESPOND_AND_SET_NULL_IF_ERROR(
response,
TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"unexpected number of string elements " +
std::to_string(element_idx + 1) + " for inference input '" +
name + "', expecting " + std::to_string(request_element_cnt))
.c_str()));
return cuda_copy;
}

const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
content += sizeof(uint32_t);
content_byte_size -= sizeof(uint32_t);

if (content_byte_size < len) {
RESPOND_AND_SET_NULL_IF_ERROR(
response,
TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"incomplete string data for inference input '" +
std::string(name) + "', expecting string of length " +
std::to_string(len) + " but only " +
std::to_string(content_byte_size) + " bytes available")
.c_str()));
FillStringTensor(input_list, request_element_cnt - element_idx);
return cuda_copy;
}

auto callback = [](torch::List<std::string>* input_list, const char* content,
const uint32_t len) {
// Set string value
input_list->push_back(std::string(content, len));
};
auto fn = std::bind(
callback, input_list, std::placeholders::_2, std::placeholders::_3);

content += len;
content_byte_size -= len;
element_idx++;
}

if ((*response != nullptr) && (element_idx != request_element_cnt)) {
RESPOND_AND_SET_NULL_IF_ERROR(
response, TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
std::string(
"expected " + std::to_string(request_element_cnt) +
" strings for inference input '" + name + "', got " +
std::to_string(element_idx))
.c_str()));
err = ValidateStringBuffer(
content, content_byte_size, request_element_cnt, name, &element_idx, fn);
if (err != nullptr) {
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
if (element_idx < request_element_cnt) {
FillStringTensor(input_list, request_element_cnt - element_idx);
}
}

return cuda_copy;
}

Expand Down

0 comments on commit ea51c02

Please sign in to comment.