Add BOS/EOS token test

Ita Zaporozhets 2024-06-21 10:09:48 +02:00
parent 2dbb00ce80
commit c490db21f0
2 changed files with 17 additions and 1 deletion

src/transformers/models/llama/tokenization_llama.py

@@ -250,7 +250,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
             tokens = tokens[1:]
         return tokens
 
-    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
+    # Modified from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
     def _tokenize(self, text, **kwargs):
         """
         Returns a tokenized string.
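Note: in transformers, "# Copied from ..." markers are enforced by the repo-consistency tooling (utils/check_copies.py), which requires the marked function to stay identical to its source. Switching the marker to "# Modified from ..." takes _tokenize out of that check, so the Llama implementation is free to diverge from the T5 one. To re-run the check locally (assuming a development checkout of transformers):

    python utils/check_copies.py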

tests/models/llama/test_tokenization_llama.py

@@ -24,6 +24,7 @@ from datasets import load_dataset
 from transformers import (
     SPIECE_UNDERLINE,
     AddedToken,
+    AutoTokenizer,
     LlamaTokenizer,
     LlamaTokenizerFast,
 )
@@ -32,6 +33,7 @@ from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
     require_jinja,
+    require_read_token,
     require_sentencepiece,
     require_tokenizers,
     require_torch,
@@ -822,3 +824,17 @@ class CommonSpmIntegrationTests(unittest.TestCase):
         self.assertEqual(input_ids, [284, 1, 156])
         tokens = self.tokenizer.tokenize("No <s> ▁He")
         self.assertEqual(tokens, ["▁No", "<s>", "▁He"])  # spaces are eaten by rstrip / lstrip
+
+    @require_read_token
+    def test_bos_eos_tokens(self):
+        tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", add_bos_token=False, add_eos_token=True)
+        assert tokenizer("hello")["input_ids"][0] != tokenizer.bos_token_id  # no bos token
+        assert tokenizer("hello")["input_ids"][-1] == tokenizer.eos_token_id  # eos token
+
+        tokenizer.add_special_tokens({"eos_token": "<new_eos>"})  # switch to a new eos token
+        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
+        assert tokens[-1] == "<new_eos>"
+
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", add_bos_token=True, add_eos_token=True)
+        assert tokenizer("hello")["input_ids"][0] == tokenizer.bos_token_id  # bos token
+        assert tokenizer("hello")["input_ids"][-1] == tokenizer.eos_token_id  # eos token
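For reference, a minimal sketch of the behavior the new test pins down (not part of the commit; it assumes network access, and the meta-llama checkpoint is gated, which is why the test carries @require_read_token):

    from transformers import AutoTokenizer

    # BOS/EOS insertion can be toggled at load time; the extra kwargs are
    # forwarded to the underlying Llama tokenizer class.
    tok = AutoTokenizer.from_pretrained(
        "huggyllama/llama-7b", add_bos_token=False, add_eos_token=True
    )
    ids = tok("hello")["input_ids"]
    print(ids[0] == tok.bos_token_id)   # False -- BOS disabled
    print(ids[-1] == tok.eos_token_id)  # True -- EOS appended

    # Re-registering the EOS token after loading should propagate to
    # tokenize() as well, which is what the new assertions cover.
    tok.add_special_tokens({"eos_token": "<new_eos>"})
    print(tok.tokenize("hello", add_special_tokens=True)[-1])  # <new_eos>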