Fixing a pathological case for slow tokenizers (#14981)

* Fixing a pathological case for slow tokenizers

* Update src/transformers/tokenization_utils.py
Nicolas Patry 2021-12-30 09:10:34 +01:00 committed by GitHub
parent d1ba56d8d8
commit d7d60df0ec
2 changed files with 14 additions and 2 deletions
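
The pathological case involves overlapping added tokens: here "B" and "CD" both start inside "ABC", and the Trie's lookahead could miss that an earlier partial match had already completed as a longer token, so a shorter overlapping token won the split. Below is a minimal reproduction of the case covered by the new test_trie_skip, assuming Trie is importable from transformers.tokenization_utils (the file patched below):

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("ABC")
trie.add("B")   # starts inside "ABC"
trie.add("CD")  # starts inside "ABC" as well

# With this fix the earliest, longest match "ABC" wins and the
# remainder is split normally.
print(trie.split("ABCD"))  # ["ABC", "D"]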

src/transformers/tokenization_utils.py

@@ -131,7 +131,7 @@ class Trie:
         # This is used by the lookahead which needs to skip over
         # some text where the full match exceeded the place in the initial
         # for loop
-        skip = None
+        skip = 0
         # Main loop, Giving this algorithm O(n) complexity
         for current, current_char in enumerate(text):
             if skip and current < skip:
@@ -175,6 +175,11 @@ class Trie:
                             lookahead_index = current
                             end = current
                         next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                        if "" in looktrie_pointer:
+                            start = lookstart
+                            end = lookahead_index
+                            skip = lookahead_index
+
                         while next_char in looktrie_pointer:
                             looktrie_pointer = looktrie_pointer[next_char]
                             lookahead_index += 1
@@ -219,7 +224,7 @@ class Trie:
             # If this character is a starting character within the trie
             # start keeping track of this partial match.
-            if current_char in self.data:
+            if current >= skip and current_char in self.data:
                 states[current] = self.data[current_char]
 
         # We have a cut at the end with states.
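
Taken together, the three hunks above let the lookahead commit an earlier partial match that has already completed as a longer token (the "" key marks a terminal trie node), record in skip how far that match consumed the text, and keep the main loop from opening new partial matches inside the consumed span. That guard is also why skip must start at 0 rather than None: it now appears in the ordinary comparison current >= skip. What follows is a simplified, self-contained sketch of that control flow, not the library's actual Trie code; split_sketch and its plain-string vocab set are illustrative stand-ins.

def split_sketch(text, vocab):
    cuts = [0]   # offsets where the text will be cut
    skip = 0     # first index NOT yet consumed by a committed match;
                 # 0 (not None) because it feeds the `current >= skip` guard
    states = {}  # start index -> prefix matched so far
    for current, char in enumerate(text):
        if current < skip:
            continue  # inside a span an earlier, longer match consumed
        # 1. Commit: if a tracked prefix is a complete token, look for the
        #    earliest start yielding a full token, extended as far as
        #    possible (mirrors the lookahead over earlier partial matches).
        committed = False
        for start in sorted(states):
            if states[start] in vocab:
                best_start, best_end = start, start + len(states[start])
                for lookstart in sorted(states):
                    if lookstart > start:
                        break
                    candidates = [t for t in vocab if text.startswith(t, lookstart)]
                    if candidates:
                        tok = max(candidates, key=len)
                        if lookstart + len(tok) >= best_end:
                            best_start, best_end = lookstart, lookstart + len(tok)
                            break  # earliest qualifying start wins
                cuts += [best_start, best_end]
                skip = best_end  # forbid new matches inside the committed span
                states = {}
                committed = True
                break
        # 2. Extend surviving partial matches with the current character.
        if not committed:
            states = {
                s: p + char
                for s, p in states.items()
                if any(t.startswith(p + char) for t in vocab)
            }
        # 3. The fixed guard: only open a new partial match when `current`
        #    lies outside the region the lookahead already consumed.
        if current >= skip and any(t.startswith(char) for t in vocab):
            states.setdefault(current, char)
    # Commit a prefix that completed exactly at the end of the text.
    for start in sorted(states):
        if states[start] in vocab:
            cuts += [start, start + len(states[start])]
            break
    cuts.append(len(text))
    return [text[s:e] for s, e in zip(cuts, cuts[1:]) if s < e]

print(split_sketch("ABCD", {"ABC", "B", "CD"}))  # ['ABC', 'D']

Dropping the current >= skip guard from this sketch lets a new partial match open at index 2, inside the committed "ABC", and the output degrades to ['ABC', 'CD'], emitting the already-consumed "C" a second time.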

tests/test_tokenization_common.py

@@ -3687,6 +3687,13 @@ class TrieTest(unittest.TestCase):
         trie.add("C")
         self.assertEqual(trie.split("ABC"), ["AB", "C"])
 
+    def test_trie_skip(self):
+        trie = Trie()
+        trie.add("ABC")
+        trie.add("B")
+        trie.add("CD")
+        self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
+
     def test_cut_text_hardening(self):
         # Even if the offsets are wrong, we necessarily output correct string
         # parts.
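
The new test pins the greedy, longest-match-from-the-earliest-start behaviour on exactly the overlap pattern that used to go wrong. Assuming the test file shown above is tests/test_tokenization_common.py and the repository's usual pytest setup, the regression case can be run on its own:

python -m pytest tests/test_tokenization_common.py -k test_trie_skip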