From c490db21f0673d4507a2d03ba79d02c0a5dba7df Mon Sep 17 00:00:00 2001
From: Ita Zaporozhets
Date: Fri, 21 Jun 2024 10:09:48 +0200
Subject: [PATCH] Adding bos eos test

---
 .../models/llama/tokenization_llama.py        |  2 +-
 tests/models/llama/test_tokenization_llama.py | 16 ++++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py
index cc03c1470ee..2a4487f5da5 100644
--- a/src/transformers/models/llama/tokenization_llama.py
+++ b/src/transformers/models/llama/tokenization_llama.py
@@ -250,7 +250,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
             tokens = tokens[1:]
         return tokens
 
-    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
+    # Modified from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
     def _tokenize(self, text, **kwargs):
         """
         Returns a tokenized string.
diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index e45149672a8..918cce2f478 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -24,6 +24,7 @@ from datasets import load_dataset
 from transformers import (
     SPIECE_UNDERLINE,
     AddedToken,
+    AutoTokenizer,
     LlamaTokenizer,
     LlamaTokenizerFast,
 )
@@ -32,6 +33,7 @@ from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
     require_jinja,
+    require_read_token,
     require_sentencepiece,
     require_tokenizers,
     require_torch,
@@ -822,3 +824,17 @@ class CommonSpmIntegrationTests(unittest.TestCase):
         self.assertEqual(input_ids, [284, 1, 156])
         tokens = self.tokenizer.tokenize("No <s> ▁He")
         self.assertEqual(tokens, ["▁No", "<s>", "▁He"])  # spaces are eaten by rstrip / lstrip
+
+    @require_read_token
+    def test_bos_eos_tokens(self):
+        tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", add_bos_token=False, add_eos_token=True)
+        assert tokenizer("hello")["input_ids"][0] != tokenizer.bos_token_id  # no bos token
+        assert tokenizer("hello")["input_ids"][-1] == tokenizer.eos_token_id  # eos token
+
+        tokenizer.add_special_tokens({"eos_token": "<new_eos>"})  # update new eos token
+        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
+        assert tokens[-1] == "<new_eos>"
+
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", add_bos_token=True, add_eos_token=True)
+        assert tokenizer("hello")["input_ids"][0] == tokenizer.bos_token_id  # bos token
+        assert tokenizer("hello")["input_ids"][-1] == tokenizer.eos_token_id  # eos token
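
Note (not part of the patch): a minimal sketch of the behavior the new test asserts, assuming `transformers` is installed and the checkpoints are downloadable (meta-llama/Meta-Llama-3-8B is gated, hence the `@require_read_token` decorator in the test):

from transformers import AutoTokenizer

# With add_bos_token=False / add_eos_token=True, encoded ids should end with the
# EOS id and should not start with the BOS id.
tokenizer = AutoTokenizer.from_pretrained(
    "huggyllama/llama-7b", add_bos_token=False, add_eos_token=True
)
ids = tokenizer("hello")["input_ids"]
print(ids[0] == tokenizer.bos_token_id)   # expected: False (no BOS prepended)
print(ids[-1] == tokenizer.eos_token_id)  # expected: True (EOS appended)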