[LlamaTokenizerFast] Refactor default llama (#28881)

* push legacy to fast as well

* super strange

* Update src/transformers/convert_slow_tokenizer.py

* make sure we are BC

* fix Llama test

* nit

* revert

* more test

* style

* update

* small update w.r.t tokenizers

* nit

* don't split

* lol

* add a test for `add_prefix_space=False`

* fix gemma tokenizer as well

* update

* fix gemma

* nicer failures

* fixup

* update

* fix the example for legacy = False

* use `huggyllama/llama-7b` for the PR doctest

* nit

* use from_slow

* fix llama
Authored by Arthur on 2024-04-23 23:12:59 +02:00, committed by GitHub
parent 12c39e5693
commit e34da3ee3c
5 changed files with 142 additions and 31 deletions

src/transformers/convert_slow_tokenizer.py

@@ -105,7 +105,7 @@ class GemmaSentencePieceExtractor(SentencePieceExtractor):

         # there is a missing token in the vocab. We have to do this to support merges
         # "<0x09>" is the bytefallback for `\t`
-        vocab["\t"] = vocab.pop("<0x09>")
+        vocab["\t"] = vocab.get("<0x09>")

         if vocab_scores is not None:
             vocab_scores, reverse = dict(vocab_scores), True
@@ -1276,7 +1276,7 @@ class GemmaConvert(SpmConverter):
         return vocab

     def pre_tokenizer(self, replacement, add_prefix_space):
-        return None
+        return pre_tokenizers.Split(" ", "merged_with_previous")

     def unk_id(self, proto):
         unk_id = 3
@@ -1329,7 +1329,7 @@ class GemmaConvert(SpmConverter):
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
             )
         user_defined_symbols = [
-            AddedToken(token, normalized=False, special=False) for token in proto.trainer_spec.user_defined_symbols
+            AddedToken(token, normalized=True, special=False) for token in proto.trainer_spec.user_defined_symbols
         ]
         tokenizer.add_tokens(user_defined_symbols)
         return tokenizer
@@ -1393,14 +1393,18 @@ class LlamaConverter(SpmConverter):
         return tokenizer

     def normalizer(self, proto):
-        sequence = []
-        if hasattr(self.original_tokenizer, "add_prefix_space"):
-            if self.original_tokenizer.add_prefix_space:
+        if self.original_tokenizer.legacy:
+            sequence = []
+            if getattr(self.original_tokenizer, "add_prefix_space"):
                 sequence += [normalizers.Prepend(prepend="▁")]
-        sequence += [normalizers.Replace(pattern=" ", content="▁")]
-        return normalizers.Sequence(sequence)
+            sequence += [normalizers.Replace(pattern=" ", content="▁")]
+            return normalizers.Sequence(sequence)
+        return None  # non-legacy, no normalizer

     def pre_tokenizer(self, replacement, add_prefix_space):
+        if not self.original_tokenizer.legacy:  # non-legacy, we need a replace
+            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
         return None

     def post_processor(self):
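To make the `LlamaConverter` change above concrete, here is a minimal sketch of the non-legacy pre-tokenization, calling the `tokenizers` API directly. It assumes a recent `tokenizers` release in which `Metaspace` accepts `prepend_scheme` and `split`; the exact scheme chosen by `_get_prepend_scheme` depends on `add_prefix_space` and `legacy`.

```python
from tokenizers import pre_tokenizers

# Non-legacy path: no normalizer; a Metaspace pre-tokenizer prepends "▁" only
# to the very first word ("first" scheme) and, with split=False, keeps the
# text in one piece instead of splitting on "▁".
pre_tok = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="first", split=False)
print(pre_tok.pre_tokenize_str("Hello how are you"))
# [('▁Hello▁how▁are▁you', (0, 17))]
```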

src/transformers/models/llama/tokenization_llama.py

@@ -99,30 +99,30 @@ class LlamaTokenizer(PreTrainedTokenizer):
             Whether or not to add spaces between special tokens.
         legacy (`bool`, *optional*):
             Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
-            and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
-            example:
+            and #25224 which includes fixes to properly handle tokens that appear after special tokens.
+            Make sure to also set `from_slow` to `True`.
+            A simple example:

             - `legacy=True`:
             ```python
-            >>> from transformers import T5Tokenizer
+            >>> from transformers import LlamaTokenizerFast

-            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
-            >>> tokenizer.encode("Hello <extra_id_0>.")
-            [8774, 32099, 3, 5, 1]
+            >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.")  # 869 is '▁.'
+            [1, 15043, 29871, 1, 869]
             ```
             - `legacy=False`:
             ```python
-            >>> from transformers import T5Tokenizer
+            >>> from transformers import LlamaTokenizerFast

-            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
-            >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
-            [8774, 32099, 5, 1]
+            >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.")  # 29889 is '.'
+            [1, 15043, 29871, 1, 29889]
             ```
             Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
         add_prefix_space (`bool`, *optional*, defaults to `True`):
             Whether or not to add an initial space to the input. This allows to treat the leading word just as any
-            other word.
+            other word. Again, this should be set with `from_slow=True` to make sure it's taken into account.
     """

     vocab_files_names = VOCAB_FILES_NAMES

src/transformers/models/llama/tokenization_llama_fast.py

@@ -91,7 +91,30 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
         add_eos_token (`bool`, *optional*, defaults to `False`):
             Whether or not to add an `eos_token` at the end of sequences.
         use_default_system_prompt (`bool`, *optional*, defaults to `False`):
-            Whether or not the default system prompt for Llama should be used.
+            Whether or not the default system prompt for Llama should be used
+        legacy (`bool`, *optional*):
+            Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
+            and #25224 which includes fixes to properly handle tokens that appear after special tokens.
+            Make sure to also set `from_slow` to `True`.
+            A simple example:
+
+            - `legacy=True`:
+            ```python
+            >>> from transformers import LlamaTokenizerFast
+
+            >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.")  # 869 is '▁.'
+            [1, 15043, 29871, 1, 869]
+            ```
+            - `legacy=False`:
+            ```python
+            >>> from transformers import LlamaTokenizerFast
+
+            >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.")  # 29889 is '.'
+            [1, 15043, 29871, 1, 29889]
+            ```
+            Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
         add_prefix_space (`bool`, *optional*):
             Whether or not the tokenizer should automatically add a prefix space
     """
@@ -112,9 +135,21 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
         add_bos_token=True,
         add_eos_token=False,
         use_default_system_prompt=False,
+        legacy=None,
         add_prefix_space=None,
         **kwargs,
     ):
+        if legacy is None:
+            logger.warning_once(
+                f"You are using the default legacy behaviour of the {self.__class__}. This is"
+                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
+                " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
+                " means, and thoroughly read the reason why this was added as explained in"
+                " https://github.com/huggingface/transformers/pull/24565"
+            )
+            legacy = True
+        self.legacy = legacy
+
         if add_prefix_space is not None:
             logger.warning_once(
                 "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"

tests/models/gemma/test_tokenization_gemma.py

@@ -30,6 +30,7 @@ from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
     require_jinja,
+    require_read_token,
     require_sentencepiece,
     require_tokenizers,
     require_torch,
@@ -136,11 +137,12 @@ class GemmaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertTrue(special_token_id in cr_output)

     @slow
+    @require_read_token
     def test_tokenizer_integration(self):
         expected_encoding = {'input_ids': [[2, 158434, 591, 84193, 3836, 685, 6599, 31223, 235290, 140247, 578, 6599, 31223, 235290, 145139, 235290, 3491, 235275, 6572, 3311, 235290, 38197, 109959, 591, 25894, 235269, 162174, 235290, 235284, 235269, 1791, 6362, 12481, 235269, 1576, 18622, 235269, 2900, 1136, 86684, 235269, 29092, 4632, 16994, 604, 13146, 14944, 40371, 591, 19700, 235327, 235275, 578, 13146, 14944, 25511, 591, 235300, 12474, 235275, 675, 1163, 235248, 235304, 235284, 235340, 229903, 5377, 575, 235248, 235274, 235276, 235276, 235340, 17044, 578, 5271, 1061, 118345, 1865, 125247, 235269, 8745, 111226, 578, 176888, 235265], [2, 25894, 603, 6869, 577, 953, 235290, 8297, 5271, 209099, 41642, 774, 748, 78253, 2793, 731, 51506, 34346, 611, 2145, 2731, 578, 1833, 4807, 575, 832, 16630, 235265], [2, 651, 4320, 8426, 25341, 36271, 1163, 573, 27894, 5929, 235265]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}  # fmt: skip
         self.tokenizer_integration_test_util(
             expected_encoding=expected_encoding,
-            model_name="hf-internal-testing/dummy-gemma",
+            model_name="google/gemma-2b",
             revision="",
             padding=False,
         )
@@ -318,7 +320,13 @@ class GemmaIntegrationTest(unittest.TestCase):
             encoded1 = pyth_tokenizer.encode(string)
             encoded2 = rust_tokenizer.encode(string)

-            self.assertEqual(encoded1, encoded2)
+            self.assertEqual(
+                encoded1,
+                encoded2,
+                msg="Hint: the following tokenization diff were obtained for slow vs fast:\n "
+                f"elements in slow: {set(pyth_tokenizer.tokenize(string))-set(rust_tokenizer.tokenize(string))} \nvs\n "
+                f"elements in fast: {set(rust_tokenizer.tokenize(string))-set(pyth_tokenizer.tokenize(string))} \n\n{string}",
+            )

             decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
             decoded2 = rust_tokenizer.decode(encoded1, skip_special_tokens=True)
@@ -332,7 +340,7 @@ class GemmaIntegrationTest(unittest.TestCase):
             encoded1 = pyth_tokenizer.encode(string)
             encoded2 = rust_tokenizer.encode(string)

-            self.assertEqual(encoded1, encoded2)
+            self.assertEqual(encoded1, encoded2, msg=f"failed on {string}")

             decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
             decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
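The improved assertions above are plain `unittest`; on a mismatch, the `msg` surfaces the symmetric difference of the slow and fast token sets. A self-contained sketch of the same pattern, with illustrative stand-in values that deliberately fail so the hint is printed:

```python
import unittest


class SetDiffHintExample(unittest.TestCase):
    def test_hint(self):
        # Stand-ins for pyth_tokenizer.tokenize(string) / rust_tokenizer.tokenize(string)
        slow, fast = ["▁Hello", "!"], ["Hello", "!"]
        self.assertEqual(
            slow,
            fast,
            msg="Hint: the following tokenization diff were obtained for slow vs fast:\n "
            f"elements in slow: {set(slow) - set(fast)} \nvs\n "
            f"elements in fast: {set(fast) - set(slow)}",
        )


if __name__ == "__main__":
    unittest.main()  # fails by design, printing the two differing elements
```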

tests/models/llama/test_tokenization_llama.py

@@ -543,8 +543,15 @@ class LlamaIntegrationTest(unittest.TestCase):
     def test_special_token_special_word(self):
         # the word inform should be split as ['in', 'form']
-        tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
+        tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True)
         tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)

+        example_inputs = tokenizer.tokenize("<REPR_END>inform<s>. Hey.       .")
+        self.assertEqual(example_inputs, ["<REPR_END>", "in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
+
+        # Make sure dummy space is added if it is indeed the first word
+        example_inputs = tokenizer.tokenize("inform<s>. Hey.       .")
+        self.assertEqual(example_inputs, ["▁inform", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
+
         out1 = tokenizer.decode(
             tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
         )
@@ -553,12 +560,12 @@ class LlamaIntegrationTest(unittest.TestCase):
             tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
         )
         # decoding strips the added prefix space.
-        self.assertEqual(out2, "<REPR_END> inform")
+        self.assertEqual(out2, "<REPR_END>inform")
         input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
-        self.assertEqual(input_ids, [29871, 32000, 262, 689])  # 29871 is the spiece underline, '▁' added as it should
+        self.assertEqual(input_ids, [32000, 262, 689])  # 29871 is the spiece underline, '▁' added as it should

         out2 = tokenizer.decode(
-            tokenizer.encode(" <REPR_END> inform", add_special_tokens=False), spaces_between_special_tokens=False
+            tokenizer.encode(" <REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
         )
         # TODO @ArthurZ currently we strip left and right, so this will not keep the spaces
         self.assertEqual(out2, "<REPR_END>inform")
@@ -575,11 +582,11 @@ class LlamaIntegrationTest(unittest.TestCase):
         # Let's make sure that if there are any spaces, we don't remove them!
         input_ids = tokenizer.encode(" <s> Hello<s> how", add_special_tokens=False)
-        self.assertEqual(input_ids, [259, 1, 15043, 1, 920])
+        self.assertEqual(input_ids, [29871, 1, 15043, 1, 920])
         tokens = tokenizer.tokenize(" <s> Hello<s> how", add_special_tokens=False)
-        self.assertEqual(tokens, ["▁▁", "<s>", "▁Hello", "<s>", "▁how"])
+        self.assertEqual(tokens, ["▁", "<s>", "▁Hello", "<s>", "▁how"])
         decoded_tokens = tokenizer.decode(input_ids)
-        self.assertEqual(decoded_tokens, " <s> Hello<s> how")
+        self.assertEqual(decoded_tokens, "<s> Hello<s> how")

         # Let's make sure the space is preserved
         input_ids = tokenizer.encode("hello", add_special_tokens=True)
@@ -594,6 +601,63 @@ class LlamaIntegrationTest(unittest.TestCase):
         decoded_tokens = tokenizer.decode(input_ids)
         self.assertEqual(decoded_tokens, "hello")

+    def test_no_prefix_space(self):
+        tokenizer = LlamaTokenizerFast.from_pretrained(
+            "huggyllama/llama-7b", legacy=False, from_slow=True, add_prefix_space=False
+        )
+        tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
+
+        example_inputs = tokenizer.tokenize("<REPR_END>inform<s>. Hey.       .")
+        self.assertEqual(example_inputs, ["<REPR_END>", "in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
+
+        # Make sure dummy space is added if it is indeed the first word
+        example_inputs = tokenizer.tokenize("inform<s>. Hey.       .")
+        self.assertEqual(example_inputs, ["in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
+        out1 = tokenizer.decode(
+            tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
+        )
+        self.assertEqual(out1, "<REPR_END>inform")
+        out2 = tokenizer.decode(
+            tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
+        )
+        # decoding strips the added prefix space.
+        self.assertEqual(out2, "<REPR_END>inform")
+        input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
+        self.assertEqual(input_ids, [32000, 262, 689])  # 29871 is the spiece underline, '▁' added as it should
+
+        out2 = tokenizer.decode(
+            tokenizer.encode(" <REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
+        )
+        self.assertEqual(out2, "<REPR_END>inform")
+
+        input_ids = tokenizer.encode("<s> Hello<s>how", add_special_tokens=False)
+        self.assertEqual(input_ids, [1, 15043, 1, 3525])
+        tokens = tokenizer.tokenize("<s> Hello<s>how", add_special_tokens=False)
+        self.assertEqual(tokens, ["<s>", "▁Hello", "<s>", "how"])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, "<s> Hello<s>how")
+
+        # Let's make sure that if there are any spaces, we don't remove them!
+        input_ids = tokenizer.encode(" <s> Hello<s> how", add_special_tokens=False)
+        self.assertEqual(input_ids, [29871, 1, 15043, 1, 920])
+        tokens = tokenizer.tokenize(" <s> Hello<s> how", add_special_tokens=False)
+        self.assertEqual(tokens, ["▁", "<s>", "▁Hello", "<s>", "▁how"])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, " <s> Hello<s> how")
+
+        # Let's make sure the space is preserved
+        input_ids = tokenizer.encode("hello", add_special_tokens=True)
+        self.assertEqual(input_ids, [1, 12199])
+        tokens = tokenizer.tokenize("hello")
+        self.assertEqual(tokens, ["hello"])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, "<s>hello")
+
+        input_ids = tokenizer.encode("hello", add_special_tokens=False)
+        self.assertEqual(input_ids, [12199])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, "hello")
+
     def test_some_edge_cases(self):
         tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
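In short, the new `test_no_prefix_space` pins down the `add_prefix_space=False` path. A sketch of the observable difference, reusing the checkpoint from the tests above:

```python
from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained(
    "huggyllama/llama-7b", legacy=False, from_slow=True, add_prefix_space=False
)
print(tok.tokenize("inform"))  # ['in', 'form'] -- no dummy '▁' is prepended
# With the default add_prefix_space=True, the same call yields ['▁inform'].
```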