This commit is contained in:
Ita Zaporozhets 2025-07-02 21:46:03 +02:00 committed by GitHub
commit e89b1e63fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 96 additions and 0 deletions

View File

@ -96,6 +96,8 @@ loaded very simply into 🤗 transformers. Take a look at the [Using tokenizers
- batch_decode
- decode
- encode
- add_bos_token
- add_eos_token
- push_to_hub
- all

View File

@ -990,6 +990,8 @@ class SpecialTokensMixin:
# if we are adding tokens that were not part of the vocab, we ought to add them
added_tokens = self.add_tokens(added_tokens, special_tokens=True)
if hasattr(self, "update_post_processor"):
self.update_post_processor()
return added_tokens
def add_tokens(

View File

@ -26,6 +26,7 @@ from typing import Any, Optional, Union
import tokenizers.pre_tokenizers as pre_tokenizers_fast
from tokenizers import Encoding as EncodingFast
from tokenizers import Tokenizer as TokenizerFast
from tokenizers import processors
from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
@ -174,8 +175,16 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
kwargs.setdefault("max_length", _padding["length"])
kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
self._add_bos_token = kwargs.get("add_bos_token", None)
self._add_eos_token = kwargs.get("add_eos_token", None)
# We call this after having initialized the backend tokenizer because we update it.
super().__init__(**kwargs)
if "add_bos_token" in kwargs or "add_eos_token" in kwargs:
self.update_post_processor()
# Set the splitting mode for special tokens for the tokenizer to be used throughout the class.
self._tokenizer.encode_special_tokens = self.split_special_tokens
added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
@ -920,3 +929,58 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
kwargs["additional_special_tokens"] = additional_special_tokens
return self.__class__(tokenizer_object=tokenizer, **kwargs)
def update_post_processor(self):
    """
    Overwrites the underlying post processor with the current `bos_token` and `eos_token`.

    Rebuilds a `TemplateProcessing` post processor from the current `bos_token`/`eos_token`
    values and the `add_bos_token`/`add_eos_token` flags, so that changing any of them takes
    effect on subsequent encodes.

    Raises:
        ValueError: if `add_bos_token` (resp. `add_eos_token`) is set but the corresponding
            token is `None`.
    """
    # Only template-style processors can be rebuilt from bos/eos settings; leave any
    # other kind of post processor untouched.
    if not isinstance(
        self._tokenizer.post_processor, (processors.TemplateProcessing, processors.Sequence)
    ):
        return
    # `logger.warn` is a deprecated alias of `logger.warning` — use the supported name.
    logger.warning(
        "Warning overwriting the original postProcessor in order to update `bos_token` or `eos_token`. "
        "Reload the tokenizer without these parameters if that is not desired"
    )
    bos = self.bos_token
    bos_token_id = self.bos_token_id
    if bos is None and self.add_bos_token:
        raise ValueError("add_bos_token = True but bos_token = None")

    eos = self.eos_token
    eos_token_id = self.eos_token_id
    if eos is None and self.add_eos_token:
        raise ValueError("add_eos_token = True but eos_token = None")

    # Template strings: `$A`/`$B` are the first/second sequence, `:0`/`:1` the type ids.
    single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
    pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"

    special_tokens = []
    if self.add_bos_token:
        special_tokens.append((bos, bos_token_id))
    if self.add_eos_token:
        special_tokens.append((eos, eos_token_id))
    # Only overwrite when at least one special token is actually inserted.
    if special_tokens:
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=single, pair=pair, special_tokens=special_tokens
        )
@property
def add_bos_token(self):
    """Whether a BOS token is inserted by the post processor."""
    return self._add_bos_token

@add_bos_token.setter
def add_bos_token(self, value):
    """Toggle BOS insertion and rebuild the post processor accordingly."""
    self._add_bos_token = value
    self.update_post_processor()

@property
def add_eos_token(self):
    """Whether an EOS token is appended by the post processor."""
    return self._add_eos_token

@add_eos_token.setter
def add_eos_token(self, value):
    """Toggle EOS insertion and rebuild the post processor accordingly."""
    self._add_eos_token = value
    self.update_post_processor()

View File

@ -839,6 +839,34 @@ class CommonSpmIntegrationTests(unittest.TestCase):
tokens = self.tokenizer.tokenize("No <s> ▁He")
self.assertEqual(tokens, ["▁No", "<s>", "▁He"]) # spaces are eaten by rstrip / lstrip
@require_read_token
def test_bos_eos_tokens(self):
    """Check that `add_bos_token`/`add_eos_token` and an updated `eos_token` are honored
    at encode time and survive a save/reload round trip."""
    new_eos_token = "<new_eos>"
    model_path = "hf-internal-testing/llama-3-8b-internal"
    tokenizer = AutoTokenizer.from_pretrained(model_path, add_bos_token=False, add_eos_token=True)
    self.assertNotEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)  # no bos token
    self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)  # eos token

    tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
    tokens = tokenizer.tokenize("hello", add_special_tokens=True)
    self.assertEqual(tokens[-1], new_eos_token)
    # NOTE(review): the tokenizer was created with add_bos_token=False, so expecting
    # input_ids[0] == bos_token_id here contradicts the assertNotEqual above — confirm intent.
    self.assertEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)
    self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)

    tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
    tokens = tokenizer.tokenize("hello", add_special_tokens=True)
    self.assertEqual(tokens[-1], new_eos_token)

    # Use a context manager so the temp dir is removed even when an assertion fails;
    # the original mkdtemp/rmtree pair leaked the directory on any failure in between.
    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer.save_pretrained(tmpdirname)
        tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname)
        self.assertIsInstance(tokenizer_reload, PreTrainedTokenizerFast)
        tokens = tokenizer_reload.tokenize("hello", add_special_tokens=True)
        self.assertEqual(tokens[-1], new_eos_token)
@require_tiktoken
@require_read_token