commit e89b1e63fa
Ita Zaporozhets, 2025-07-02 21:46:03 +02:00 (committed by GitHub)
4 changed files with 96 additions and 0 deletions


@@ -96,6 +96,8 @@ loaded very simply into 🤗 transformers. Take a look at the [Using tokenizers
    - batch_decode
    - decode
    - encode
    - add_bos_token
    - add_eos_token
    - push_to_hub
    - all


@@ -990,6 +990,8 @@ class SpecialTokensMixin:
        # if we are adding tokens that were not part of the vocab, we ought to add them
        added_tokens = self.add_tokens(added_tokens, special_tokens=True)
        if hasattr(self, "update_post_processor"):
            self.update_post_processor()
        return added_tokens

    def add_tokens(
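
With this hook in place, swapping a special token through `add_special_tokens` also rebuilds the backend post-processor on fast tokenizers that define `update_post_processor`. A minimal sketch of the intended behaviour, using the public `hf-internal-testing/llama-tokenizer` checkpoint as an illustrative stand-in (not part of this diff):

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any fast tokenizer whose backend uses TemplateProcessing behaves the same.
    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", add_eos_token=True)
    tok.add_special_tokens({"eos_token": "<my_eos>"})  # now also refreshes the post-processor
    print(tok.tokenize("hello", add_special_tokens=True)[-1])  # expected: "<my_eos>"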


@@ -26,6 +26,7 @@ from typing import Any, Optional, Union
import tokenizers.pre_tokenizers as pre_tokenizers_fast
from tokenizers import Encoding as EncodingFast
from tokenizers import Tokenizer as TokenizerFast
from tokenizers import processors
from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
@@ -174,8 +175,16 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
            kwargs.setdefault("max_length", _padding["length"])
            kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])

        self._add_bos_token = kwargs.get("add_bos_token", None)
        self._add_eos_token = kwargs.get("add_eos_token", None)

        # We call this after having initialized the backend tokenizer because we update it.
        super().__init__(**kwargs)
        if "add_bos_token" in kwargs or "add_eos_token" in kwargs:
            self.update_post_processor()

        # Set the splitting mode for special tokens for the tokenizer to be used throughout the class.
        self._tokenizer.encode_special_tokens = self.split_special_tokens

        added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
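
Because `__init__` now records `add_bos_token` / `add_eos_token` and refreshes the post-processor when either kwarg is passed, the flags can be set at load time on any `PreTrainedTokenizerFast`, not only on model-specific subclasses. A hedged sketch mirroring the test added below (the checkpoint name is an assumption used only for illustration):

    from transformers import AutoTokenizer

    # Assumed public test checkpoint, not taken from this commit.
    tok = AutoTokenizer.from_pretrained(
        "hf-internal-testing/llama-tokenizer", add_bos_token=False, add_eos_token=True
    )
    ids = tok("hello")["input_ids"]
    assert ids[0] != tok.bos_token_id   # no BOS prepended
    assert ids[-1] == tok.eos_token_id  # EOS appended by the rebuilt post-processor
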
@@ -920,3 +929,58 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
            kwargs["additional_special_tokens"] = additional_special_tokens

        return self.__class__(tokenizer_object=tokenizer, **kwargs)
    def update_post_processor(self):
        """
        Overwrites the underlying post processor with the current `bos_token` and `eos_token`.
        """
        if not isinstance(self._tokenizer.post_processor, processors.TemplateProcessing) and not isinstance(
            self._tokenizer.post_processor, processors.Sequence
        ):
            return
        logger.warning(
            "Overwriting the original post_processor in order to update `bos_token` or `eos_token`. "
            "Reload the tokenizer without these parameters if that is not desired."
        )
        bos = self.bos_token
        bos_token_id = self.bos_token_id
        if bos is None and self.add_bos_token:
            raise ValueError("add_bos_token = True but bos_token = None")

        eos = self.eos_token
        eos_token_id = self.eos_token_id
        if eos is None and self.add_eos_token:
            raise ValueError("add_eos_token = True but eos_token = None")

        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"

        special_tokens = []
        if self.add_bos_token:
            special_tokens.append((bos, bos_token_id))
        if self.add_eos_token:
            special_tokens.append((eos, eos_token_id))

        if special_tokens:
            self._tokenizer.post_processor = processors.TemplateProcessing(
                single=single, pair=pair, special_tokens=special_tokens
            )

    @property
    def add_eos_token(self):
        return self._add_eos_token

    @property
    def add_bos_token(self):
        return self._add_bos_token

    @add_eos_token.setter
    def add_eos_token(self, value):
        self._add_eos_token = value
        self.update_post_processor()

    @add_bos_token.setter
    def add_bos_token(self, value):
        self._add_bos_token = value
        self.update_post_processor()
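
For reference, with both flags enabled and the usual `<s>` / `</s>` tokens, the templates built above come out as `<s>:0 $A:0 </s>:0` for a single sequence and `<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1` for a pair. Since the setters re-run `update_post_processor`, the flags can also be flipped on an already-loaded tokenizer; a small sketch, again assuming an illustrative LLaMA-style checkpoint:

    from transformers import AutoTokenizer

    # Illustrative checkpoint, not taken from this commit.
    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    print(tok.tokenize("hello", add_special_tokens=True))  # e.g. ['<s>', '▁hello'], BOS only by default
    tok.add_eos_token = True  # setter rebuilds the TemplateProcessing post-processor
    print(tok.tokenize("hello", add_special_tokens=True))  # e.g. ['<s>', '▁hello', '</s>']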


@@ -839,6 +839,34 @@ class CommonSpmIntegrationTests(unittest.TestCase):
        tokens = self.tokenizer.tokenize("No <s> ▁He")
        self.assertEqual(tokens, ["▁No", "<s>", "▁He"])  # spaces are eaten by rstrip / lstrip

    @require_read_token
    def test_bos_eos_tokens(self):
        new_eos_token = "<new_eos>"
        model_path = "hf-internal-testing/llama-3-8b-internal"
        tokenizer = AutoTokenizer.from_pretrained(model_path, add_bos_token=False, add_eos_token=True)
        self.assertNotEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)  # no bos token
        self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)  # eos token

        tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
        self.assertEqual(tokens[-1], new_eos_token)

        self.assertEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)
        self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)

        tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
        self.assertEqual(tokens[-1], new_eos_token)

        tmpdirname = tempfile.mkdtemp()
        tokenizer.save_pretrained(tmpdirname)
        tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname)
        self.assertTrue(isinstance(tokenizer_reload, PreTrainedTokenizerFast))
        tokens = tokenizer_reload.tokenize("hello", add_special_tokens=True)
        self.assertEqual(tokens[-1], new_eos_token)
        shutil.rmtree(tmpdirname)
    @require_tiktoken
    @require_read_token