Merge 02e75eb37d into 2d561713f8
This commit is contained in commit e89b1e63fa
@@ -96,6 +96,8 @@ loaded very simply into 🤗 transformers. Take a look at the [Using tokenizers
     - batch_decode
     - decode
     - encode
+    - add_bos_token
+    - add_eos_token
     - push_to_hub
     - all

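The two new entries document `add_bos_token` and `add_eos_token`, which can now be passed when loading any fast tokenizer. A minimal usage sketch, assuming a checkpoint whose saved post-processor is a `tokenizers` TemplateProcessing or Sequence (for example a Llama-style tokenizer); the checkpoint id below is a placeholder:

from transformers import AutoTokenizer

# Placeholder checkpoint id: any fast tokenizer whose saved post-processor is a
# TemplateProcessing / Sequence (e.g. a Llama-style tokenizer) should behave this way.
tok = AutoTokenizer.from_pretrained("your/llama-style-checkpoint", add_bos_token=False, add_eos_token=True)

ids = tok("hello")["input_ids"]
print(ids[0] == tok.bos_token_id)   # False: no BOS prepended
print(ids[-1] == tok.eos_token_id)  # True: EOS appended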
@@ -990,6 +990,8 @@ class SpecialTokensMixin:
         # if we are adding tokens that were not part of the vocab, we ought to add them
         added_tokens = self.add_tokens(added_tokens, special_tokens=True)
+        if hasattr(self, "update_post_processor"):
+            self.update_post_processor()
         return added_tokens

     def add_tokens(
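With this hook in `SpecialTokensMixin.add_special_tokens`, replacing a special token on a fast tokenizer also refreshes the backend post-processor, so the replacement token is what actually gets inserted. A sketch, reusing the `tok` loaded above with `add_eos_token=True`:

# Reuses `tok` from the previous sketch (loaded with add_eos_token=True).
tok.add_special_tokens({"eos_token": "<new_eos>"})  # fast tokenizers now call update_post_processor() here

tokens = tok.tokenize("hello", add_special_tokens=True)
print(tokens[-1])  # "<new_eos>": the rebuilt template appends the new EOS token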
@@ -26,6 +26,7 @@ from typing import Any, Optional, Union
 import tokenizers.pre_tokenizers as pre_tokenizers_fast
 from tokenizers import Encoding as EncodingFast
 from tokenizers import Tokenizer as TokenizerFast
+from tokenizers import processors
 from tokenizers.decoders import Decoder as DecoderFast
 from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer

@@ -174,8 +175,16 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
             kwargs.setdefault("max_length", _padding["length"])
             kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])

+        self._add_bos_token = kwargs.get("add_bos_token", None)
+        self._add_eos_token = kwargs.get("add_eos_token", None)
+
         # We call this after having initialized the backend tokenizer because we update it.
         super().__init__(**kwargs)
+
+        if "add_bos_token" in kwargs or "add_eos_token" in kwargs:
+            self.update_post_processor()
+
+        # Set the splitting mode for special tokens for the tokenizer to be used throughout the class.
         self._tokenizer.encode_special_tokens = self.split_special_tokens

         added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
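The private `_add_bos_token` / `_add_eos_token` attributes are assigned before `super().__init__(**kwargs)`, presumably so that `update_post_processor()`, which reads them via the new properties, can run safely once the base class has finished initializing; the rebuild itself only happens when the caller explicitly passed one of the flags, so checkpoints that never set them keep their saved post-processor. A short sketch of the init-time effect, again with a placeholder checkpoint id and assuming the saved post-processor is a TemplateProcessing or Sequence:

from tokenizers import processors
from transformers import AutoTokenizer

# Placeholder checkpoint id. Passing either flag at load time makes __init__ call
# update_post_processor(), which installs a TemplateProcessing on the backend
# tokenizer reflecting the requested flags (assuming the saved post-processor was
# a TemplateProcessing / Sequence to begin with).
tok = AutoTokenizer.from_pretrained("your/llama-style-checkpoint", add_bos_token=True, add_eos_token=True)

print(tok.add_bos_token, tok.add_eos_token)  # True True
print(isinstance(tok.backend_tokenizer.post_processor, processors.TemplateProcessing))  # True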
@@ -920,3 +929,58 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         kwargs["additional_special_tokens"] = additional_special_tokens

         return self.__class__(tokenizer_object=tokenizer, **kwargs)
+
+    def update_post_processor(self):
+        """
+        Overwrites the underlying post processor with the current `bos_token` and `eos_token`.
+        """
+        if not isinstance(self._tokenizer.post_processor, processors.TemplateProcessing) and not isinstance(
+            self._tokenizer.post_processor, processors.Sequence
+        ):
+            return
+
+        logger.warn(
+            "Warning overwriting the original postProcessor in order to update `bos_token` or `eos_token`. "
+            "Reload the tokenizer without these parameters if that is not desired"
+        )
+
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+        if bos is None and self.add_bos_token:
+            raise ValueError("add_bos_token = True but bos_token = None")
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+        if eos is None and self.add_eos_token:
+            raise ValueError("add_eos_token = True but eos_token = None")
+
+        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
+        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        if special_tokens:
+            self._tokenizer.post_processor = processors.TemplateProcessing(
+                single=single, pair=pair, special_tokens=special_tokens
+            )
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
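The new properties make both flags toggleable after construction; each setter immediately re-runs `update_post_processor()`. A sketch of toggling at runtime, assuming `tok` is a fast tokenizer loaded as in the first example:

# Assumes `tok` was loaded as in the first sketch (placeholder checkpoint,
# TemplateProcessing / Sequence post-processor). Each setter rebuilds the template.
tok.add_eos_token = False
print(tok("hello")["input_ids"][-1] == tok.eos_token_id)  # False: EOS no longer appended

tok.add_bos_token = True
print(tok("hello")["input_ids"][0] == tok.bos_token_id)   # True: BOS now prepended

# With bos = "<s>", eos = "</s>", add_bos_token=True and add_eos_token=True,
# update_post_processor() builds:
#   single: "<s>:0 $A:0 </s>:0"
#   pair:   "<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1"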
@@ -839,6 +839,34 @@ class CommonSpmIntegrationTests(unittest.TestCase):
         tokens = self.tokenizer.tokenize("No <s> ▁He")
         self.assertEqual(tokens, ["▁No", "<s>", "▁He"])  # spaces are eaten by rstrip / lstrip
+
+    @require_read_token
+    def test_bos_eos_tokens(self):
+        new_eos_token = "<new_eos>"
+        model_path = "hf-internal-testing/llama-3-8b-internal"
+
+        tokenizer = AutoTokenizer.from_pretrained(model_path, add_bos_token=False, add_eos_token=True)
+        self.assertNotEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)  # no bos token
+        self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)  # eos token
+
+        tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
+        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
+        self.assertEqual(tokens[-1], new_eos_token)
+
+        self.assertEqual(tokenizer("hello")["input_ids"][0], tokenizer.bos_token_id)
+        self.assertEqual(tokenizer("hello")["input_ids"][-1], tokenizer.eos_token_id)
+
+        tokenizer.add_special_tokens({"eos_token": new_eos_token})  # update new eos token
+        tokens = tokenizer.tokenize("hello", add_special_tokens=True)
+        self.assertEqual(tokens[-1], new_eos_token)
+
+        tmpdirname = tempfile.mkdtemp()
+        tokenizer.save_pretrained(tmpdirname)
+        tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname)
+
+        self.assertTrue(isinstance(tokenizer_reload, PreTrainedTokenizerFast))
+        tokens = tokenizer_reload.tokenize("hello", add_special_tokens=True)
+        self.assertEqual(tokens[-1], new_eos_token)
+        shutil.rmtree(tmpdirname)


 @require_tiktoken
 @require_read_token
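The save/reload assertions at the end of the test rely on the updated `eos_token` surviving serialization. A condensed sketch of that round trip outside of unittest, again with a placeholder checkpoint id:

import tempfile

from transformers import AutoTokenizer, PreTrainedTokenizerFast

# Placeholder checkpoint id; mirrors the round trip exercised by the test above.
tok = AutoTokenizer.from_pretrained("your/llama-style-checkpoint", add_eos_token=True)
tok.add_special_tokens({"eos_token": "<new_eos>"})

with tempfile.TemporaryDirectory() as tmpdir:
    tok.save_pretrained(tmpdir)
    reloaded = AutoTokenizer.from_pretrained(tmpdir)  # fully loaded into memory before cleanup

assert isinstance(reloaded, PreTrainedTokenizerFast)
assert reloaded.tokenize("hello", add_special_tokens=True)[-1] == "<new_eos>"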