[SPM] Patch spm Llama and T5 (#25656)

* hot fix

* only encode with string prefix if starts with prefix

* styling

* add a new test

* fixup
This commit is contained in:
Arthur 2023-08-23 07:16:43 +02:00 committed by GitHub
parent 57943630e2
commit 51794bf21e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 12 deletions

View File

@@ -220,13 +220,14 @@ class LlamaTokenizer(PreTrainedTokenizer):
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
"""
if self.legacy:
return self.sp_model.encode(text, out_type=str)
unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
text = self.unk_token + text
tokens = self.sp_model.encode(text, out_type=str)
return tokens[unk_token_length:]
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
return tokens
# 1. Encode string + prefix ex: "<unk> Hey"
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""

View File

@@ -363,6 +363,10 @@ class T5Tokenizer(PreTrainedTokenizer):
tokens = tokens[1:]
return tokens
@property
def unk_token_length(self):
    """Number of SentencePiece pieces the `unk_token` string is encoded into.

    E.g. with `unk_token = "<unk>"` the sp_model may split it as
    `['<', 'unk', '>']`, giving a length of 3 — used to strip the
    prepended unk-token pieces after encoding `unk_token + text`.
    """
    return len(self.sp_model.encode(str(self.unk_token)))
def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string.
@@ -373,13 +377,14 @@ class T5Tokenizer(PreTrainedTokenizer):
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
"""
if self.legacy:
return self.sp_model.encode(text, out_type=str)
unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
text = self.unk_token + text
tokens = self.sp_model.encode(text, out_type=str)
return tokens[unk_token_length:]
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
return tokens
# 1. Encode string + prefix ex: "<unk> Hey"
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""

View File

@@ -546,6 +546,15 @@ class LlamaIntegrationTest(unittest.TestCase):
decoded_tokens = tokenizer.decode(input_ids)
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
def test_some_edge_cases(self):
    # Edge case: "<s>>" mixes a special token ("<s>") with raw text (">").
    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
    # The raw sp_model has no notion of the "<s>" special token, so it
    # splits the string purely on learned pieces.
    sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str)
    self.assertEqual(sp_tokens, ["<", "s", ">>"])
    # tokenize() must recognize "<s>" as a special token first and only
    # sp-encode the remainder, so its output differs from raw sp_model output.
    tokens = tokenizer.tokenize("<s>>")
    self.assertNotEqual(sp_tokens, tokens)
    self.assertEqual(tokens, ["<s>", ">"])
@require_sentencepiece
@require_tokenizers