Mirror of https://github.com/huggingface/transformers.git
Merge 02e75eb37d into 2d561713f8 (commit e89b1e63fa)
@@ -96,6 +96,8 @@ loaded very simply into 🤗 transformers. Take a look at the [Using tokenizers
     - batch_decode
     - decode
     - encode
+    - add_bos_token
+    - add_eos_token
     - push_to_hub
     - all
@@ -990,6 +990,8 @@ class SpecialTokensMixin:

         # if we are adding tokens that were not part of the vocab, we ought to add them
         added_tokens = self.add_tokens(added_tokens, special_tokens=True)
+        if hasattr(self, "update_post_processor"):
+            self.update_post_processor()
         return added_tokens

     def add_tokens(
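A note on this hook: `PreTrainedTokenizerFast` gains an `update_post_processor` method later in this diff, so a plain `add_special_tokens` call on a fast tokenizer now also refreshes the backend post processor. A minimal sketch of the intended effect, using a hypothetical checkpoint name (not taken from this diff):

```python
from transformers import AutoTokenizer

# Hypothetical checkpoint; assumes a fast tokenizer whose backend uses a
# TemplateProcessing post processor (e.g. a Llama-style tokenizer).
tok = AutoTokenizer.from_pretrained("some-org/some-llama-checkpoint", add_eos_token=True)

# Swapping the EOS token now also reruns update_post_processor() through the
# hasattr hook above, so the new token is used on the very next encode call.
tok.add_special_tokens({"eos_token": "<new_eos>"})
assert tok.tokenize("hello", add_special_tokens=True)[-1] == "<new_eos>"
```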
@@ -26,6 +26,7 @@ from typing import Any, Optional, Union
 import tokenizers.pre_tokenizers as pre_tokenizers_fast
 from tokenizers import Encoding as EncodingFast
 from tokenizers import Tokenizer as TokenizerFast
+from tokenizers import processors
 from tokenizers.decoders import Decoder as DecoderFast
 from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
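The newly imported `processors` module is the part of 🤗 tokenizers that `update_post_processor` builds on. For readers unfamiliar with it, a standalone sketch of a `TemplateProcessing` post processor (token strings and ids are illustrative):

```python
from tokenizers import processors

# "$A" / "$B" stand for the first / second sequence, ":0" / ":1" are the token
# type ids, and special_tokens maps each literal token in the templates to its id.
post_processor = processors.TemplateProcessing(
    single="<s>:0 $A:0 </s>:0",
    pair="<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1",
    special_tokens=[("<s>", 1), ("</s>", 2)],
)
```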
@@ -174,8 +175,16 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
             kwargs.setdefault("max_length", _padding["length"])
             kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])

+        self._add_bos_token = kwargs.get("add_bos_token", None)
+        self._add_eos_token = kwargs.get("add_eos_token", None)
+
         # We call this after having initialized the backend tokenizer because we update it.
         super().__init__(**kwargs)

+        if "add_bos_token" in kwargs or "add_eos_token" in kwargs:
+            self.update_post_processor()
+
         # Set the splitting mode for special tokens for the tokenizer to be used throughout the class.
         self._tokenizer.encode_special_tokens = self.split_special_tokens

         added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
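With these lines, `add_bos_token` and `add_eos_token` can be passed directly to the constructor or to `from_pretrained` of any `PreTrainedTokenizerFast`; if either flag is given, the post processor is rebuilt right after `super().__init__`. A hedged usage sketch (checkpoint name is illustrative):

```python
from transformers import AutoTokenizer

# Hypothetical checkpoint; the flags are simply forwarded through from_pretrained's kwargs.
tok = AutoTokenizer.from_pretrained(
    "some-org/some-llama-checkpoint", add_bos_token=False, add_eos_token=True
)

ids = tok("hello")["input_ids"]
assert ids[0] != tok.bos_token_id   # no BOS prepended
assert ids[-1] == tok.eos_token_id  # EOS appended
```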
@@ -920,3 +929,58 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         kwargs["additional_special_tokens"] = additional_special_tokens

         return self.__class__(tokenizer_object=tokenizer, **kwargs)
+
+    def update_post_processor(self):
+        """
+        Overwrites the underlying post processor with the current `bos_token` and `eos_token`.
+        """
+        if not isinstance(self._tokenizer.post_processor, processors.TemplateProcessing) and not isinstance(
+            self._tokenizer.post_processor, processors.Sequence
+        ):
+            return
+
+        logger.warning(
+            "Overwriting the original post processor in order to update `bos_token` or `eos_token`. "
+            "Reload the tokenizer without these parameters if this is not desired."
+        )
+
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+        if bos is None and self.add_bos_token:
+            raise ValueError("add_bos_token = True but bos_token = None")
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+        if eos is None and self.add_eos_token:
+            raise ValueError("add_eos_token = True but eos_token = None")
+
+        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
+        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        if special_tokens:
+            self._tokenizer.post_processor = processors.TemplateProcessing(
+                single=single, pair=pair, special_tokens=special_tokens
+            )
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
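To make the two f-strings concrete: with `bos = "<s>"`, `eos = "</s>"` and both flags enabled, `single` becomes `"<s>:0 $A:0 </s>:0"` and `pair` becomes `"<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1"`. Thanks to the setters, the flags can also be flipped on an already-loaded tokenizer; a minimal sketch (checkpoint name is illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("some-org/some-llama-checkpoint", add_eos_token=False)

# Assigning to the property reruns update_post_processor(), so EOS appending
# takes effect immediately, without reloading the tokenizer.
tok.add_eos_token = True
assert tok("hello")["input_ids"][-1] == tok.eos_token_id
```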
@@ -839,6 +839,34 @@ class CommonSpmIntegrationTests(unittest.TestCase):
         tokens = self.tokenizer.tokenize("No <s> ▁He")
         self.assertEqual(tokens, ["▁No", "<s>", "▁He"])  # spaces are eaten by rstrip / lstrip
+
+    @require_read_token
+    def test_bos_eos_tokens(self):
+        new_eos_token = "<new_eos>"
+        model_path = "hf-internal-testing/llama-3-8b-internal"
+        tokenizer = AutoTokenizer.from_pretrained(model_path, add_bos_token=False, add_eos_token=True)
+        self.assertNotEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)  # no bos token
+        self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)  # eos token
+
+        tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
+        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
+        self.assertEqual(tokens[-1], new_eos_token)
+
+        self.assertEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)
+        self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)
+
+        tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
+        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
+        self.assertEqual(tokens[-1], new_eos_token)
+
+        tmpdirname = tempfile.mkdtemp()
+        tokenizer.save_pretrained(tmpdirname)
+        tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname)
+
+        self.assertTrue(isinstance(tokenizer_reload, PreTrainedTokenizerFast))
+        tokens = tokenizer_reload.tokenize("hello", add_special_tokens=True)
+        self.assertEqual(tokens[-1], new_eos_token)
+        shutil.rmtree(tmpdirname)
+
+
     @require_tiktoken
     @require_read_token
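Outside the unittest harness, the round trip exercised by this test can be sketched as follows; the rebuilt post processor is saved together with the tokenizer, so the swapped EOS token survives a reload (checkpoint name is illustrative, not the gated one used in the test):

```python
import tempfile

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("some-org/some-llama-checkpoint", add_eos_token=True)
tok.add_special_tokens({"eos_token": "<new_eos>"})

with tempfile.TemporaryDirectory() as tmp:
    tok.save_pretrained(tmp)
    reloaded = AutoTokenizer.from_pretrained(tmp)
    # The reloaded tokenizer still appends the updated EOS token.
    assert reloaded.tokenize("hello", add_special_tokens=True)[-1] == "<new_eos>"
```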