Revert "refactor: Add string input checks (#136)"
This reverts commit 0d76fbf.
yinggeh committed Aug 24, 2024
1 parent 7425a42 commit d9feb72
Showing 1 changed file with 57 additions and 11 deletions.
68 changes: 57 additions & 11 deletions src/libtorch.cc
@@ -1914,6 +1914,7 @@ SetStringInputTensor(
     cudaStream_t stream, const char* host_policy_name)
 {
   bool cuda_copy = false;
+  size_t element_idx = 0;

   // For string data type, we always need to have the data on CPU so
   // that we can read string length and construct the string
@@ -1928,7 +1929,7 @@ SetStringInputTensor(
         stream, &cuda_copy);
     if (err != nullptr) {
       RESPOND_AND_SET_NULL_IF_ERROR(response, err);
-      FillStringTensor(input_list, request_element_cnt);
+      FillStringTensor(input_list, request_element_cnt - element_idx);
       return cuda_copy;
     }

@@ -1939,19 +1940,64 @@
   }
 #endif  // TRITON_ENABLE_GPU

-  std::vector<std::pair<const char*, const uint32_t>> str_list;
-  err = ValidateStringBuffer(
-      content, content_byte_size, request_element_cnt, name, &str_list);
-  // Set string values.
-  for (const auto& [addr, len] : str_list) {
-    input_list->push_back(std::string(addr, len));
+  // Parse content and assign to 'tensor'. Each string in 'content'
+  // is a 4-byte length followed by the string itself with no
+  // null-terminator.
+  while (content_byte_size >= sizeof(uint32_t)) {
+    if (element_idx >= request_element_cnt) {
+      RESPOND_AND_SET_NULL_IF_ERROR(
+          response,
+          TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INVALID_ARG,
+              std::string(
+                  "unexpected number of string elements " +
+                  std::to_string(element_idx + 1) + " for inference input '" +
+                  name + "', expecting " + std::to_string(request_element_cnt))
+                  .c_str()));
+      return cuda_copy;
+    }
+
+    const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
+    content += sizeof(uint32_t);
+    content_byte_size -= sizeof(uint32_t);
+
+    if (content_byte_size < len) {
+      RESPOND_AND_SET_NULL_IF_ERROR(
+          response,
+          TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INVALID_ARG,
+              std::string(
+                  "incomplete string data for inference input '" +
+                  std::string(name) + "', expecting string of length " +
+                  std::to_string(len) + " but only " +
+                  std::to_string(content_byte_size) + " bytes available")
+                  .c_str()));
+      FillStringTensor(input_list, request_element_cnt - element_idx);
+      return cuda_copy;
+    }
+
+    // Set string value
+    input_list->push_back(std::string(content, len));
+
+    content += len;
+    content_byte_size -= len;
+    element_idx++;
   }

-  size_t element_cnt = str_list.size();
-  if (err != nullptr) {
-    RESPOND_AND_SET_NULL_IF_ERROR(response, err);
-    FillStringTensor(input_list, request_element_cnt - element_cnt);
+  if ((*response != nullptr) && (element_idx != request_element_cnt)) {
+    RESPOND_AND_SET_NULL_IF_ERROR(
+        response, TRITONSERVER_ErrorNew(
+                      TRITONSERVER_ERROR_INTERNAL,
+                      std::string(
+                          "expected " + std::to_string(request_element_cnt) +
+                          " strings for inference input '" + name + "', got " +
+                          std::to_string(element_idx))
+                          .c_str()));
+    if (element_idx < request_element_cnt) {
+      FillStringTensor(input_list, request_element_cnt - element_idx);
+    }
   }

   return cuda_copy;
 }

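Note on the restored parsing loop: it walks the serialized string (BYTES) input layout described in the comment above — each element is a 4-byte length followed by the raw bytes, with no null terminator — and when an element is malformed it calls FillStringTensor to pad the remaining request_element_cnt - element_idx entries, so the input tensor still ends up with the expected element count. The sketch below is not part of the commit; it is a minimal standalone illustration of that layout, with hypothetical helpers PackStrings/UnpackStrings that mirror the loop above.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical helper: serialize strings into the length-prefixed layout
// described in the diff comment (4-byte length, then the bytes, no '\0').
std::vector<char>
PackStrings(const std::vector<std::string>& strs)
{
  std::vector<char> buf;
  for (const std::string& s : strs) {
    const uint32_t len = static_cast<uint32_t>(s.size());
    const char* len_bytes = reinterpret_cast<const char*>(&len);
    buf.insert(buf.end(), len_bytes, len_bytes + sizeof(uint32_t));
    buf.insert(buf.end(), s.begin(), s.end());
  }
  return buf;
}

// Hypothetical helper: the inverse walk, mirroring the restored while-loop.
std::vector<std::string>
UnpackStrings(const char* content, size_t content_byte_size)
{
  std::vector<std::string> out;
  while (content_byte_size >= sizeof(uint32_t)) {
    uint32_t len = 0;
    std::memcpy(&len, content, sizeof(uint32_t));  // avoid unaligned reads
    content += sizeof(uint32_t);
    content_byte_size -= sizeof(uint32_t);
    if (content_byte_size < len) {
      break;  // incomplete element; the backend reports an error here instead
    }
    out.emplace_back(content, len);
    content += len;
    content_byte_size -= len;
  }
  return out;
}

int
main()
{
  const std::vector<std::string> inputs{"hello", "", "triton"};
  const std::vector<char> buf = PackStrings(inputs);
  for (const std::string& s : UnpackStrings(buf.data(), buf.size())) {
    std::cout << "[" << s << "]\n";
  }
  return 0;
}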
