Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
[SPM] Patch spm Llama and T5 (#25656)

* hot fix
* only encode with string prefix if starts with prefix
* styling
* add a new test
* fixup
This commit is contained in:
parent 57943630e2
commit 51794bf21e
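In short: with `legacy=False`, both tokenizers work around SentencePiece's disabled `add_dummy_prefix` option by encoding `unk_token + text` and slicing the `unk_token` pieces back off. After this patch, that workaround only fires when `text` actually starts with `SPIECE_UNDERLINE` or a space, the slice length comes from a shared `unk_token_length` property, and the slice is guarded against token lists shorter than the prefix.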
@@ -220,13 +220,14 @@ class LlamaTokenizer(PreTrainedTokenizer):
         `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
         """
-        if self.legacy:
-            return self.sp_model.encode(text, out_type=str)
-
-        unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
-        text = self.unk_token + text
         tokens = self.sp_model.encode(text, out_type=str)
-        return tokens[unk_token_length:]
+        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
 
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
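To make the patched control flow easier to read outside the class, here is a standalone sketch of the same logic. This is not code from the commit: `sp` stands in for a loaded `sentencepiece.SentencePieceProcessor` with `add_dummy_prefix` disabled, and the exact pieces returned depend on the model file.

import sentencepiece as spm

SPIECE_UNDERLINE = "▁"

def tokenize_sketch(sp: spm.SentencePieceProcessor, text: str,
                    unk_token: str = "<unk>", legacy: bool = False) -> list:
    # Plain encode covers legacy mode and text without a leading prefix space.
    tokens = sp.encode(text, out_type=str)
    if legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
        return tokens

    # With add_dummy_prefix off, a leading "▁"/" " would otherwise be stripped,
    # so encode "<unk>" + text and slice the unk pieces back off (guarded).
    unk_token_length = len(sp.encode(unk_token, out_type=str))
    tokens = sp.encode(unk_token + text, out_type=str)
    return tokens[unk_token_length:] if len(tokens) >= unk_token_length else tokens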
@@ -363,6 +363,10 @@ class T5Tokenizer(PreTrainedTokenizer):
             tokens = tokens[1:]
         return tokens
 
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
+
     def _tokenize(self, text, **kwargs):
         """
         Returns a tokenized string.
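The new `unk_token_length` property just measures how many pieces the `unk_token` string occupies in the SentencePiece vocabulary, so the slice in `_tokenize` stays correct even if a model splits "<unk>" unusually. A quick way to inspect it on a loaded tokenizer (the checkpoint name is only an example; the value depends on the vocabulary):

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small", legacy=False)
# Pieces that "<unk>" encodes to, and the count used for slicing in _tokenize.
print(tok.sp_model.encode(str(tok.unk_token), out_type=str))
print(tok.unk_token_length)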
@@ -373,13 +377,14 @@ class T5Tokenizer(PreTrainedTokenizer):
         `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
         """
-        if self.legacy:
-            return self.sp_model.encode(text, out_type=str)
-
-        unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
-        text = self.unk_token + text
         tokens = self.sp_model.encode(text, out_type=str)
-        return tokens[unk_token_length:]
+        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
 
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
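The T5 `_tokenize` change mirrors the Llama one above; the key piece in both is the new guard. A minimal, model-free illustration of which inputs trigger the unk-prefix workaround after this patch:

SPIECE_UNDERLINE = "▁"

for text in ("Hey", "<s>>", " Hey", "▁Hey"):
    applies = text.startswith((SPIECE_UNDERLINE, " "))
    print(f"{text!r}: workaround applied -> {applies}")
# Only " Hey" and "▁Hey" take the encode-with-prefix path; "Hey" and "<s>>"
# are now encoded directly, which the old code never did.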
@@ -546,6 +546,15 @@ class LlamaIntegrationTest(unittest.TestCase):
         decoded_tokens = tokenizer.decode(input_ids)
         self.assertEqual(decoded_tokens, " <s> Hello<s> how")
 
+    def test_some_edge_cases(self):
+        tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
+
+        sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str)
+        self.assertEqual(sp_tokens, ["<", "s", ">>"])
+        tokens = tokenizer.tokenize("<s>>")
+        self.assertNotEqual(sp_tokens, tokens)
+        self.assertEqual(tokens, ["<s>", ">"])
+
 
 @require_sentencepiece
 @require_tokenizers
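The new test pins down the difference between raw SentencePiece output and tokenize() for an input that touches the "<s>" special token. To reproduce it interactively (network access and the huggyllama/llama-7b checkpoint are assumed; the expected values are the ones asserted above):

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
# Raw sentencepiece has no notion of added/special tokens:
print(tokenizer.sp_model.encode("<s>>", out_type=str))  # ['<', 's', '>>']
# The tokenizer recognizes "<s>" as a special token and splits around it:
print(tokenizer.tokenize("<s>>"))                       # ['<s>', '>']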