mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 09:42:22 +06:00
Add BOS/EOS token test
This commit is contained in:
parent
2dbb00ce80
commit
c490db21f0
@ -250,7 +250,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
tokens = tokens[1:]
|
||||
return tokens
|
||||
|
||||
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
|
||||
# Modified from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
|
||||
def _tokenize(self, text, **kwargs):
|
||||
"""
|
||||
Returns a tokenized string.
|
||||
|
@ -24,6 +24,7 @@ from datasets import load_dataset
|
||||
from transformers import (
|
||||
SPIECE_UNDERLINE,
|
||||
AddedToken,
|
||||
AutoTokenizer,
|
||||
LlamaTokenizer,
|
||||
LlamaTokenizerFast,
|
||||
)
|
||||
@ -32,6 +33,7 @@ from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
nested_simplify,
|
||||
require_jinja,
|
||||
require_read_token,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
@ -822,3 +824,17 @@ class CommonSpmIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(input_ids, [284, 1, 156])
|
||||
tokens = self.tokenizer.tokenize("No <s> ▁He")
|
||||
self.assertEqual(tokens, ["▁No", "<s>", "▁He"]) # spaces are eaten by rstrip / lstrip
|
||||
|
||||
@require_read_token
def test_bos_eos_tokens(self):
    """Check that add_bos_token / add_eos_token flags and EOS replacement are honored."""
    # Llama tokenizer configured to append EOS but not prepend BOS.
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", add_bos_token=False, add_eos_token=True)
    encoded = tokenizer("hello")["input_ids"]
    assert encoded[0] != tokenizer.bos_token_id  # no bos token
    assert encoded[-1] == tokenizer.eos_token_id  # eos token

    # Registering a new EOS token must be reflected in subsequent tokenization.
    tokenizer.add_special_tokens({"eos_token": "<new_eos>"})  # update new eos token
    pieces = tokenizer.tokenize("hello", add_special_tokens=True)
    assert pieces[-1] == "<new_eos>"

    # Llama 3 tokenizer with both BOS and EOS enabled.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", add_bos_token=True, add_eos_token=True)
    encoded = tokenizer("hello")["input_ids"]
    assert encoded[0] == tokenizer.bos_token_id  # bos token
    assert encoded[-1] == tokenizer.eos_token_id  # eos token
|
||||
|
Loading…
Reference in New Issue
Block a user