diff --git a/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.cc b/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.cc index 90fed95650a..062f77c0243 100644 --- a/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.cc +++ b/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.cc @@ -66,18 +66,33 @@ absl::Status TfLiteInterpreterWrapper::SetStringData( // [4] offset of each string (int32_t) // [sizeof(int32_t) * (num_strings + 1)]] total size of strings // [sizeof(int32_t) * (num_strings + 2)] batch.data() - int32_t num_strings = batch_size; - offset_.clear(); + (void)batch_size; + std::vector offsets; size_t total_size = 0; - offset_.push_back(static_cast(total_size)); + offsets.push_back(total_size); for (const auto& tensor : tensors) { const auto& flat = tensor->flat(); for (int i = 0; i < flat.size(); ++i) { + if (flat(i).size() > std::numeric_limits::max() - total_size) { + return absl::InternalError("String input is too large."); + } total_size += flat(i).size(); - offset_.push_back(static_cast(total_size)); + offsets.push_back(total_size); } } - size_t required_bytes = total_size + sizeof(int32_t) * (num_strings + 2); + const size_t num_strings = offsets.size() - 1; + if (num_strings > std::numeric_limits::max()) { + return absl::InternalError("Too many string inputs."); + } + const size_t header_entries = num_strings + 2; + if (header_entries > std::numeric_limits::max() / sizeof(int32_t)) { + return absl::InternalError("String input header is too large."); + } + const size_t header_bytes = sizeof(int32_t) * header_entries; + if (total_size > std::numeric_limits::max() - header_bytes) { + return absl::InternalError("String input buffer is too large."); + } + size_t required_bytes = total_size + header_bytes; if (tensor_buffer_.find(tensor_index) == tensor_buffer_.end()) { return absl::InternalError( absl::StrCat("Tensor input for index not found: ", tensor_index)); @@ -87,13 +102,18 @@ absl::Status TfLiteInterpreterWrapper::SetStringData( free(tflite_tensor->data.raw); } tflite_tensor->data.raw = reinterpret_cast(malloc(required_bytes)); + if (tflite_tensor->data.raw == nullptr) { + return absl::ResourceExhaustedError("Failed to allocate string input."); + } tensor_buffer_max_bytes_[tensor_index] = required_bytes; } tensor_buffer_[tensor_index].reset(tflite_tensor->data.raw); - memcpy(tensor_buffer_[tensor_index].get(), &num_strings, sizeof(int32_t)); - int32_t start = sizeof(int32_t) * (num_strings + 2); - for (size_t i = 0; i < offset_.size(); i++) { - size_t size_offset_i = start + offset_[i]; + const int32_t num_strings_i32 = static_cast(num_strings); + memcpy(tensor_buffer_[tensor_index].get(), &num_strings_i32, + sizeof(int32_t)); + size_t start = header_bytes; + for (size_t i = 0; i < offsets.size(); i++) { + size_t size_offset_i = start + offsets[i]; if (size_offset_i > std::numeric_limits::max()) { return absl::InternalError( absl::StrCat("Invalid size, string input too large:", size_offset_i)); diff --git a/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h b/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h index 884b47da7d8..17fc9e68b92 100644 --- a/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h +++ b/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h @@ -105,7 +105,6 @@ class TfLiteInterpreterWrapper { int batch_size_ = 1; std::map> tensor_buffer_; std::map tensor_buffer_max_bytes_; - std::vector offset_; #ifdef TFLITE_PROFILE int max_num_entries_; tflite::profiling::ProfileSummarizer run_summarizer_; diff --git a/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool_test.cc b/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool_test.cc index 3fa5d7c4be9..40c6992f99b 100644 --- a/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool_test.cc +++ b/tensorflow_serving/servables/tensorflow/tflite_interpreter_pool_test.cc @@ -178,6 +178,43 @@ TEST(TfLiteInterpreterWrapper, TfLiteInterpreterWrapperTest) { ::testing::ElementsAreArray(expected_strs)); } +TEST(TfLiteInterpreterWrapper, SetStringDataUsesFlattenedStringCount) { + std::string model_bytes; + TF_ASSERT_OK(ReadFileToString(Env::Default(), + test_util::TestSrcDirPath(kParseExampleModel), + &model_bytes)); + auto model = tflite::FlatBufferModel::BuildFromModel( + flatbuffers::GetRoot(model_bytes.data())); + tflite::ops::builtin::BuiltinOpResolver resolver; + tflite::ops::custom::AddParseExampleOp(&resolver); + std::unique_ptr interpreter; + ASSERT_EQ(tflite::InterpreterBuilder(*model, resolver)(&interpreter, + /*num_threads=*/1), + kTfLiteOk); + ASSERT_EQ(interpreter->inputs().size(), 1); + const int idx = interpreter->inputs()[0]; + auto* tensor = interpreter->tensor(idx); + ASSERT_EQ(tensor->type, kTfLiteString); + ASSERT_EQ(interpreter->ResizeInputTensor(idx, {1}), kTfLiteOk); + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); + + auto interpreter_wrapper = + std::make_unique(std::move(interpreter)); + + Tensor input(DT_STRING, TensorShape({1, 2})); + input.flat()(0) = "first"; + input.flat()(1) = "second"; + std::vector data = {&input}; + + auto* wrapped = interpreter_wrapper->Get(); + tensor = wrapped->tensor(idx); + TF_ASSERT_OK(interpreter_wrapper->SetStringData(data, tensor, idx, + input.dim_size(0))); + + const auto strings = ExtractVector(wrapped->tensor(idx)); + EXPECT_THAT(strings, ::testing::ElementsAre("first", "second")); +} + } // namespace internal } // namespace serving } // namespace tensorflow