diff --git a/tensorflow_text/core/kernels/sentencepiece/BUILD b/tensorflow_text/core/kernels/sentencepiece/BUILD index 78e1eb8fd..20e5c8878 100644 --- a/tensorflow_text/core/kernels/sentencepiece/BUILD +++ b/tensorflow_text/core/kernels/sentencepiece/BUILD @@ -140,7 +140,6 @@ cc_library( "optimized_encoder.h", ], deps = [ - ":config", ":double_array_trie", ":encoder_config", ], @@ -198,7 +197,7 @@ tf_cc_library( deps = [ ":optimized_decoder", ":sentencepiece_detokenizer_h", - # tf/protobuf:error_codes_proto_impl_cc tensorflow dep, + "@com_google_absl//absl/status", ], ) diff --git a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.cc b/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.cc index 9602684d0..fee0a6ef2 100644 --- a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.cc +++ b/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.cc @@ -95,6 +95,9 @@ std::tuple find_replacement( const flatbuffers::Vector& replacements) { const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); if (!max_match.empty()) { + if (max_match.id < 0 || max_match.id >= replacements.size()) { + return std::make_tuple(0, utils::string_view(nullptr, 0)); + } // Because flatbuffer byte is signed char which is not the same as char, // there is the reinterpret_cast here. const char* replaced_string_ptr = @@ -195,6 +198,9 @@ EncoderResult EncodeNormalizedString(const std::string& str, } auto lattice_update = [&lattice, i, piece_scores](const DoubleArrayTrie::Match& m) { + if (m.id < 0 || m.id >= piece_scores->size()) { + return; + } LatticeElement& target_element = lattice[i + m.match_length]; const float score = lattice[i].score + (*piece_scores)[m.id]; if (target_element.prev_position < 0 || target_element.score < score) { diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc index 3338ca38e..7e35041f6 100644 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc +++ b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h" #include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h" @@ -60,6 +61,12 @@ class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel { for (int i = 0; i < num_of_sentences; i++) { // Create a vector of int32 from input according to spans. const int split_size = input_splits_flat(i + 1) - input_splits_flat(i); + OP_REQUIRES( + ctx, + split_size >= 0 && + (input_offset + split_size) <= input_values_flat.size(), + errors::InvalidArgument("input_splits must be monotonically " + "non-decreasing and within bounds.")); codes_for_split.clear(); codes_for_split.reserve(split_size); for (int j = 0; j < split_size; ++j) { diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_tflite.cc b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_tflite.cc index 89e8a4723..3f8f6df4d 100644 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_tflite.cc +++ b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_tflite.cc @@ -85,6 +85,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { for (int i = 0; i < num_of_sentences; i++) { // Create a vector of int32 from input according to spans. const int split_size = input_splits_data[i + 1] - input_splits_data[i]; + TF_LITE_ENSURE_MSG( + context, + split_size >= 0 && + (input_offset + split_size) <= NumElements(input_encoded.dims), + "input_splits must be monotonically non-decreasing and " + "within bounds."); codes_for_split.clear(); std::copy(input_encoded_data + input_offset, input_encoded_data + input_offset + split_size,