From a9b7a53aafa8b7a63df60c130a637c9c037333d4 Mon Sep 17 00:00:00 2001 From: "TF.Text Team" Date: Fri, 8 May 2026 10:48:03 -0700 Subject: [PATCH] Harden FastWordpieceTokenizer, FastBertNormalizer, and PhraseTokenizer with bounds verification and initialization checks: * Prevent null pointer dereferences by validating required FlatBuffer fields. * Prevent OOB reads during detokenization by checking token IDs against vocab size. * Secure FastBertNormalizer against adversarial heap slicing by tracking string pool limits. PiperOrigin-RevId: 912596984 --- tensorflow_text/core/kernels/BUILD | 3 ++ .../core/kernels/fast_bert_normalizer.h | 31 +++++++++++++++++-- .../core/kernels/fast_wordpiece_tokenizer.cc | 25 +++++++++++++++ .../core/kernels/phrase_tokenizer.cc | 20 ++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD index 57527e275..f5b968df3 100644 --- a/tensorflow_text/core/kernels/BUILD +++ b/tensorflow_text/core/kernels/BUILD @@ -211,6 +211,7 @@ tf_cc_library( ":darts_clone_trie_builder", ":darts_clone_trie_wrapper", ":fast_bert_normalizer_model", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -1254,9 +1255,11 @@ cc_library( ":string_vocab", ":whitespace_tokenizer", ":whitespace_tokenizer_config_builder", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", # lite/kernels/shim:status_macros tensorflow dep, diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer.h b/tensorflow_text/core/kernels/fast_bert_normalizer.h index efd5102bf..c7b8f1849 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer.h +++ 
b/tensorflow_text/core/kernels/fast_bert_normalizer.h @@ -15,9 +15,12 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_ +#include <cstddef> #include <memory> #include <string> +#include "absl/base/optimization.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "icu4c/source/common/unicode/utf8.h" #include "tensorflow/lite/kernels/shim/status_macros.h" @@ -83,7 +86,12 @@ class FastBertNormalizer { // lifetime of the instance. static absl::StatusOr<FastBertNormalizer> Create( const uint32_t* trie_data, int data_for_codepoint_zero, - const char* normalized_string_pool) { + const char* normalized_string_pool, + size_t normalized_string_pool_size = static_cast<size_t>(-1)) { + if (trie_data == nullptr || normalized_string_pool == nullptr) { + return absl::InvalidArgumentError( + "trie_data or normalized_string_pool is null"); + } FastBertNormalizer result; SH_ASSIGN_OR_RETURN(auto trie, trie_utils::DartsCloneTrieWrapper::Create(trie_data)); @@ -92,6 +100,7 @@ class FastBertNormalizer { result.data_for_codepoint_zero_ = data_for_codepoint_zero; result.normalized_string_pool_ = reinterpret_cast<const char*>(normalized_string_pool); + result.normalized_string_pool_size_ = normalized_string_pool_size; return result; } @@ -103,11 +112,20 @@ class FastBertNormalizer { // through the lifetime of the instance. static absl::StatusOr<FastBertNormalizer> Create( const void* model_flatbuffer) { + if (model_flatbuffer == nullptr) { + return absl::InvalidArgumentError("model_flatbuffer is null"); + } // `GetFastBertNormalizerModel()` is autogenerated by flatbuffer. 
auto model = GetFastBertNormalizerModel(model_flatbuffer); + if (model == nullptr || model->trie_array() == nullptr || + model->normalized_string_pool() == nullptr) { + return absl::InvalidArgumentError( + "FastBertNormalizerModel or its required fields are null"); + } return Create( model->trie_array()->data(), model->data_for_codepoint_zero(), - reinterpret_cast<const char*>(model->normalized_string_pool()->data())); + reinterpret_cast<const char*>(model->normalized_string_pool()->data()), + model->normalized_string_pool()->size()); } // Normalizes the input based on config `lower_case_nfd_strip_accents`. @@ -290,6 +308,12 @@ class FastBertNormalizer { } const int offset = (data & text_norm::kNormalizedStringOffsetMask) >> text_norm::kBitsToEncodeUtf8LengthOfNormalizedString; + if (ABSL_PREDICT_FALSE( + offset < 0 || + (normalized_string_pool_size_ != static_cast<size_t>(-1) && + offset + len > normalized_string_pool_size_))) { + return ""; + } return absl::string_view(normalized_string_pool_ + offset, len); } @@ -331,6 +355,9 @@ class FastBertNormalizer { // The string pool of normalized strings. Each normalized string is a // substring denoted by (offset and length). const char* normalized_string_pool_; + + // The size of normalized_string_pool_ if known, or -1. 
+ size_t normalized_string_pool_size_ = static_cast<size_t>(-1); }; } // namespace text diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc index 88feee121..ff10541f4 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc @@ -17,9 +17,11 @@ #include <string> #include "absl/base/attributes.h" +#include "absl/base/optimization.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "icu4c/source/common/unicode/uchar.h" @@ -48,6 +50,11 @@ FastWordpieceTokenizer::Create(const void* config_flatbuffer) { FastWordpieceTokenizer tokenizer; // `GetFastWordpieceTokenizerConfig()` is autogenerated by flatbuffer. tokenizer.config_ = GetFastWordpieceTokenizerConfig(config_flatbuffer); + if (tokenizer.config_ == nullptr || + tokenizer.config_->trie_array() == nullptr) { + return absl::InvalidArgumentError( + "FastWordpieceTokenizerConfig or its trie_array is null."); + } auto trie_or = trie_utils::DartsCloneTrieWrapper::Create( tokenizer.config_->trie_array()->data()); if (!trie_or.ok()) { @@ -127,8 +134,23 @@ FastWordpieceTokenizer::DetokenizeToTokens( "true in the config flatbuffer. 
Please rebuild the model flatbuffer " "by setting support_detokenization=true."); } + if (config_->vocab_array() == nullptr || + config_->vocab_is_suffix_array() == nullptr) { + return absl::InternalError( + "Missing vocab_array or vocab_is_suffix_array in config."); + } + const int vocab_size = config_->vocab_array()->size(); + const int is_suffix_size = config_->vocab_is_suffix_array()->size(); for (int id : input) { + if (ABSL_PREDICT_FALSE(id < 0 || id >= vocab_size || + id >= is_suffix_size)) { + return absl::OutOfRangeError( + absl::StrCat("Token ID out of bounds: ", id)); + } auto vocab = config_->vocab_array()->Get(id); + if (ABSL_PREDICT_FALSE(vocab == nullptr)) { + return absl::InternalError("Null vocab string in vocab_array."); + } auto is_suffix = config_->vocab_is_suffix_array()->Get(id); if (!subwords.empty() && !is_suffix) { // When current subword is not a suffix token, it marks the start of a new @@ -140,6 +162,9 @@ FastWordpieceTokenizer::DetokenizeToTokens( // Special case: when a suffix token e.g. "##a" appears at the start of the // input ids, we preserve the suffix_indicator. 
if (subwords.empty() && is_suffix) { + if (ABSL_PREDICT_FALSE(config_->suffix_indicator() == nullptr)) { + return absl::InternalError("Missing suffix_indicator in config."); + } subwords.emplace_back(config_->suffix_indicator()->string_view()); } subwords.emplace_back(vocab->string_view()); diff --git a/tensorflow_text/core/kernels/phrase_tokenizer.cc b/tensorflow_text/core/kernels/phrase_tokenizer.cc index cfffe87fe..98abddd08 100644 --- a/tensorflow_text/core/kernels/phrase_tokenizer.cc +++ b/tensorflow_text/core/kernels/phrase_tokenizer.cc @@ -20,7 +20,10 @@ #include <string> #include <vector> +#include "absl/base/optimization.h" +#include "absl/status/status.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/kernels/shim/status_macros.h" @@ -34,6 +37,12 @@ namespace text { PhraseTokenizer tokenizer; // `GetPhraseTokenizerConfig()` is autogenerated by flatbuffer. tokenizer.phrase_config_ = GetPhraseTokenizerConfig(config_flatbuffer); + if (tokenizer.phrase_config_ == nullptr || + tokenizer.phrase_config_->vocab_trie() == nullptr || + tokenizer.phrase_config_->whitespace_config() == nullptr) { + return absl::InvalidArgumentError( + "PhraseTokenizerConfig or required fields are null."); + } tokenizer.trie_ = absl::make_unique( tokenizer.phrase_config_->vocab_trie()->nodes()); tokenizer.prob_ = static_cast<float>(tokenizer.phrase_config_->prob()) / 100; @@ -174,8 +183,19 @@ absl::StatusOr<std::vector<std::string>> PhraseTokenizer::DetokenizeToTokens( "true in the config flatbuffer. 
Please rebuild the model flatbuffer " "by setting support_detokenization=true."); } + if (phrase_config_->vocab_array() == nullptr) { + return absl::InternalError("Missing vocab_array in config."); + } + const int vocab_size = phrase_config_->vocab_array()->size(); for (int id : input) { + if (ABSL_PREDICT_FALSE(id < 0 || id >= vocab_size)) { + return absl::OutOfRangeError( + absl::StrCat("Token ID out of bounds: ", id)); + } auto vocab = phrase_config_->vocab_array()->Get(id); + if (ABSL_PREDICT_FALSE(vocab == nullptr)) { + return absl::InternalError("Null vocab string in vocab_array."); + } output_tokens.emplace_back(vocab->string_view()); } return output_tokens;