diff --git a/seqio/test_utils.py b/seqio/test_utils.py
index 946c7421..1830d79a 100644
--- a/seqio/test_utils.py
+++ b/seqio/test_utils.py
@@ -1265,12 +1265,14 @@ def sentencepiece_vocab(
         sentencepiece_model_pb2.NormalizerSpec
     ] = None,
     reverse_extra_ids: bool = True,
+    use_fast_tokenizer: bool = False,
 ):
   return vocabularies.SentencePieceVocabulary(
       os.path.join(TEST_DATA_DIR, "sentencepiece", "sentencepiece.model"),
       extra_ids=extra_ids,
       normalizer_spec_overrides=normalizer_spec_overrides,
       reverse_extra_ids=reverse_extra_ids,
+      use_fast_tokenizer=use_fast_tokenizer,
   )
 
 
diff --git a/seqio/vocabularies.py b/seqio/vocabularies.py
index 85799271..658d7530 100644
--- a/seqio/vocabularies.py
+++ b/seqio/vocabularies.py
@@ -285,6 +285,7 @@ def __init__(
           sentencepiece_model_pb2.NormalizerSpec
       ] = None,
       reverse_extra_ids: bool = True,
+      use_fast_tokenizer: bool = False,
   ):
     """Create a SentencePieceVocabulary.
 
@@ -300,11 +301,14 @@
       reverse_extra_ids: if True, extra_ids are numbered in descending order,
         so the first extra_id has the highest number. This is done for
         compatibility with span_corruption mask generation in T5.
+      use_fast_tokenizer: use the tf_text FastSentencepieceTokenizer
+        implementation, which runs much faster.
     """
     self._sentencepiece_model_file = sentencepiece_model_file
     self._normalizer_spec_overrides = normalizer_spec_overrides
     self._reverse_extra_ids = reverse_extra_ids
     self._model: Optional[SentencePieceVocabulary._ModelContext] = None
+    self._use_fast_tokenizer = use_fast_tokenizer
 
     super().__init__(extra_ids=extra_ids)
 
@@ -436,6 +440,8 @@ def tokenizer(self) -> sentencepiece_processor.SentencePieceProcessor:
   @property
   def tf_tokenizer(self):
     """Instantiate and return a TF tokenizer."""
+    if self._use_fast_tokenizer:
+      return tf_text.FastSentencepieceTokenizer(model=self.sp_model)
     return tf_text.SentencepieceTokenizer(model=self.sp_model)
 
   @property
diff --git a/seqio/vocabularies_test.py b/seqio/vocabularies_test.py
index 51883718..a8077057 100644
--- a/seqio/vocabularies_test.py
+++ b/seqio/vocabularies_test.py
@@ -298,6 +298,20 @@ def test_extra_ids(self):
         test_tokens, tuple(vocab.encode_tf(test_string).numpy())
     )
 
+  def test_fast_tokenizer(self):
+    vocab = test_utils.sentencepiece_vocab(
+        extra_ids=10, use_fast_tokenizer=True)
+    self.assertEqual(36, vocab.vocab_size)
+    self.assertEqual("v", vocab.decode([25]))
+    test_string = "<extra_id_0> <extra_id_1> v <extra_id_9>"
+    test_tokens = (35, 34, 3, 25, 26)
+    self.assertEqual(test_string, vocab.decode(test_tokens))
+    self.assertEqual(test_string, _decode_tf(vocab, test_tokens))
+    self.assertSequenceEqual(test_tokens, vocab.encode(test_string))
+    self.assertSequenceEqual(
+        test_tokens, tuple(vocab.encode_tf(test_string).numpy())
+    )
+
   def test_force_repeated_whitespace_preservation(self):
     test_string = "a a a a"  # string with repeated whitespaces
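
Usage sketch (not part of the patch): a minimal example of how the new flag would be passed when constructing a vocabulary directly. The model path is a placeholder, and the snippet assumes seqio and tensorflow_text are installed.

    import tensorflow as tf
    from seqio import vocabularies

    # Hypothetical model path. Setting use_fast_tokenizer=True makes
    # encode_tf/decode_tf use tf_text.FastSentencepieceTokenizer instead of
    # tf_text.SentencepieceTokenizer.
    vocab = vocabularies.SentencePieceVocabulary(
        "/path/to/spm.model",
        extra_ids=100,
        use_fast_tokenizer=True,
    )

    token_ids = vocab.encode_tf(tf.constant("a test sentence"))
    text = vocab.decode_tf(token_ids)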