diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index 8ce3396bdd3..fef4c4a623f 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -131,7 +131,7 @@ class Trie:
         # This is used by the lookahead which needs to skip over
         # some text where the full match exceeded the place in the initial
         # for loop
-        skip = None
+        skip = 0
         # Main loop, Giving this algorithm O(n) complexity
         for current, current_char in enumerate(text):
             if skip and current < skip:
@@ -175,6 +175,11 @@ class Trie:
                             lookahead_index = current
                             end = current
                         next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                        if "" in looktrie_pointer:
+                            start = lookstart
+                            end = lookahead_index
+                            skip = lookahead_index
+
                         while next_char in looktrie_pointer:
                             looktrie_pointer = looktrie_pointer[next_char]
                             lookahead_index += 1
@@ -219,7 +224,7 @@ class Trie:
 
             # If this character is a starting character within the trie
             # start keeping track of this partial match.
-            if current_char in self.data:
+            if current >= skip and current_char in self.data:
                 states[current] = self.data[current_char]
 
         # We have a cut at the end with states.
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 54748b3e862..112eab82518 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -3687,6 +3687,13 @@ class TrieTest(unittest.TestCase):
         trie.add("C")
         self.assertEqual(trie.split("ABC"), ["AB", "C"])
 
+    def test_trie_skip(self):
+        trie = Trie()
+        trie.add("ABC")
+        trie.add("B")
+        trie.add("CD")
+        self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
+
     def test_cut_text_hardening(self):
         # Even if the offsets are wrong, we necessarily output correct string
         # parts.
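
For context, here is a minimal repro sketch of the case the new `test_trie_skip` covers. This is an illustration, not part of the patch; it assumes only the public `Trie` API (`add`/`split`) from `transformers.tokenization_utils`, which the test file itself uses.

```python
# Minimal sketch of the failure mode this patch addresses, assuming the
# Trie class from transformers.tokenization_utils (add/split are its
# public API).
from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("ABC")
trie.add("B")
trie.add("CD")

# The lookahead matches the longest token "ABC" starting at index 0.
# Before this patch, the characters it consumed could still seed new
# partial matches ("B" at index 1, "CD" at index 2), producing a wrong
# split. With `skip` set to the lookahead end and the new
# `current >= skip` guard, those stale start states are never created.
print(trie.split("ABCD"))  # expected: ['ABC', 'D']
```

Initializing `skip` to `0` instead of `None` keeps the new `current >= skip` comparison valid for every index, since comparing an `int` against `None` would raise a `TypeError`.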