From ccebefbf23440baecce9b5e52b488bc0a9a1ce2b Mon Sep 17 00:00:00 2001 From: Yingge He Date: Thu, 18 Jul 2024 15:50:21 -0700 Subject: [PATCH 1/3] Refactor string input checks --- src/onnxruntime.cc | 56 ++++++++++++++-------------------------------- 1 file changed, 17 insertions(+), 39 deletions(-) diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index 58c50c3..b41c17b 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -2485,50 +2485,28 @@ ModelInstanceState::SetStringInputBuffer( size_t element_cnt = 0; if ((*responses)[idx] != nullptr) { - size_t remaining_bytes = expected_byte_size; char* data_content = input_buffer + buffer_copy_offset; - // Continue if the remaining bytes may still contain size info - while (remaining_bytes >= sizeof(uint32_t)) { - if (element_cnt >= expected_element_cnt) { - RESPOND_AND_SET_NULL_IF_ERROR( - &((*responses)[idx]), - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unexpected number of string elements ") + - std::to_string(element_cnt + 1) + " for inference input '" + - input_name + "', expecting " + - std::to_string(expected_element_cnt)) - .c_str())); - break; - } - const uint32_t len = *(reinterpret_cast(data_content)); - remaining_bytes -= sizeof(uint32_t); - // Make first byte of size info 0, so that if there is string data - // in front of it, the data becomes valid C string. - *data_content = 0; - data_content = data_content + sizeof(uint32_t); - if (len > remaining_bytes) { - RESPOND_AND_SET_NULL_IF_ERROR( - &((*responses)[idx]), - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("incomplete string data for inference input '") + - input_name + "', expecting string of length " + - std::to_string(len) + " but only " + - std::to_string(remaining_bytes) + " bytes available") - .c_str())); - break; - } else { - string_ptrs->push_back(data_content); - element_cnt++; - data_content = data_content + len; - remaining_bytes -= len; + auto callback = [](std::vector* string_ptrs, + const size_t element_idx, const char* content, + const uint32_t len) { + // Set string value + string_ptrs->push_back(content); + }; + auto fn = std::bind( + callback, string_ptrs, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + + TRITONSERVER_Error* err = ValidateStringBuffer( + data_content, expected_byte_size, expected_element_cnt, + input_name.c_str(), &element_cnt, fn, true); + if (err != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR(&((*responses)[idx]), err); + if (element_cnt < expected_element_cnt) { + FillStringData(string_ptrs, expected_element_cnt - element_cnt); } } } - - FillStringData(string_ptrs, expected_element_cnt - element_cnt); buffer_copy_offset += expected_byte_size; } } From 067d2efa2bb0d46feef562d25776347a050ee673 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 19 Jul 2024 15:52:14 -0700 Subject: [PATCH 2/3] Improve readability --- src/onnxruntime.cc | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index b41c17b..210e742 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -2477,6 +2477,7 @@ ModelInstanceState::SetStringInputBuffer( std::vector* responses, char* input_buffer, std::vector* string_ptrs) { + std::vector> str_list; // offset for each response size_t buffer_copy_offset = 0; for (size_t idx = 0; idx < expected_byte_sizes.size(); idx++) { @@ -2486,25 +2487,21 @@ ModelInstanceState::SetStringInputBuffer( size_t element_cnt = 0; if ((*responses)[idx] != nullptr) { char* data_content = input_buffer + buffer_copy_offset; - - auto callback = [](std::vector* string_ptrs, - const size_t element_idx, const char* content, - const uint32_t len) { - // Set string value - string_ptrs->push_back(content); - }; - auto fn = std::bind( - callback, string_ptrs, std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3); - TRITONSERVER_Error* err = ValidateStringBuffer( data_content, expected_byte_size, expected_element_cnt, - input_name.c_str(), &element_cnt, fn, true); + input_name.c_str(), &str_list); + // Set string values. + for (const auto& [addr, len] : str_list) { + // Make first byte of size info 0, so that if there is string data + // in front of it, the data becomes valid C string. + *const_cast(addr - sizeof(uint32_t)) = 0; + string_ptrs->push_back(addr); + } + str_list.clear(); + if (err != nullptr) { RESPOND_AND_SET_NULL_IF_ERROR(&((*responses)[idx]), err); - if (element_cnt < expected_element_cnt) { - FillStringData(string_ptrs, expected_element_cnt - element_cnt); - } + FillStringData(string_ptrs, expected_element_cnt - element_cnt); } } buffer_copy_offset += expected_byte_size; From 104fb78dc65e7c4649f5caf6e2acb2fbaecfb0b5 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 26 Jul 2024 11:03:33 -0700 Subject: [PATCH 3/3] Minor fix --- src/onnxruntime.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index 210e742..acdb81e 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -2484,7 +2484,6 @@ ModelInstanceState::SetStringInputBuffer( const size_t expected_byte_size = expected_byte_sizes[idx]; const size_t expected_element_cnt = expected_element_cnts[idx]; - size_t element_cnt = 0; if ((*responses)[idx] != nullptr) { char* data_content = input_buffer + buffer_copy_offset; TRITONSERVER_Error* err = ValidateStringBuffer( @@ -2497,12 +2496,13 @@ ModelInstanceState::SetStringInputBuffer( *const_cast(addr - sizeof(uint32_t)) = 0; string_ptrs->push_back(addr); } - str_list.clear(); + size_t element_cnt = str_list.size(); if (err != nullptr) { RESPOND_AND_SET_NULL_IF_ERROR(&((*responses)[idx]), err); FillStringData(string_ptrs, expected_element_cnt - element_cnt); } + str_list.clear(); } buffer_copy_offset += expected_byte_size; }