Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Commit f80a9fd479 (parent e8bfe9051f): cleaning after rebase
@@ -685,16 +685,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         )
         # ["This is something", "<special_token_1>", "else"]
         tokenized_text = []
-        for idx, token in enumerate(tokens):
+        for token in tokens:
             # Need to skip eventual empty (fully stripped) tokens
-            kwargs["add_bos_token"] = kwargs.get("add_bos_token", None)
-            kwargs["add_eos_token"] = kwargs.get("add_eos_token", None)
             if not token:
                 continue
-            if idx == 0 and kwargs.get("add_special_tokens", None) is True:
-                kwargs["add_bos_token"] = getattr(self, "add_bos_token", None)
-            elif idx == len(tokens) - 1 and kwargs.get("add_special_tokens", None) is True:
-                kwargs["add_eos_token"] = getattr(self, "add_eos_token", None)
             if token in no_split_token:
                 tokenized_text.append(token)
             else:
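For orientation, the loop this hunk trims back is the slow tokenizer's dispatch over the trie-split pieces. A minimal, self-contained sketch of that behaviour, with `subword_tokenize` standing in for the model-specific `_tokenize` step (names here are illustrative, not the library's API):

```python
# Sketch of the dispatch loop above (assumed behaviour, not the actual transformers code):
# skip empty fragments, keep no-split special tokens verbatim, sub-tokenize everything else.
def split_and_tokenize(pieces, no_split_token, subword_tokenize):
    tokenized_text = []
    for token in pieces:
        # Need to skip eventual empty (fully stripped) tokens
        if not token:
            continue
        if token in no_split_token:
            tokenized_text.append(token)
        else:
            tokenized_text.extend(subword_tokenize(token))
    return tokenized_text


pieces = ["This is something", "<special_token_1>", "else"]
print(split_and_tokenize(pieces, {"<special_token_1>"}, str.split))
# -> ['This', 'is', 'something', '<special_token_1>', 'else']
```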
@@ -985,6 +985,8 @@ class SpecialTokensMixin:
 
         # if we are adding tokens that were not part of the vocab, we ought to add them
         added_tokens = self.add_tokens(added_tokens, special_tokens=True)
+        if hasattr(self, "update_post_processor"):
+            self.update_post_processor()
         return added_tokens
 
     def add_tokens(
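The added `hasattr` guard is a duck-typing hook: `SpecialTokensMixin` is shared by slow and fast tokenizers, and only classes that actually define `update_post_processor` get it called after new special tokens land in the vocab. A rough sketch of the pattern with hypothetical classes:

```python
# Hypothetical classes illustrating the hasattr-guarded hook; not transformers code.
class MixinSketch:
    def add_special_tokens_sketch(self, tokens):
        added = self.add_tokens(tokens)
        # Only tokenizers that carry a post-processor need to rebuild it after
        # the vocabulary changed; everyone else silently skips this step.
        if hasattr(self, "update_post_processor"):
            self.update_post_processor()
        return added


class SlowLike(MixinSketch):
    def add_tokens(self, tokens):
        return len(tokens)


class FastLike(SlowLike):
    def update_post_processor(self):
        print("post-processor rebuilt")


SlowLike().add_special_tokens_sketch(["<pad>"])  # no rebuild
FastLike().add_special_tokens_sketch(["<pad>"])  # prints "post-processor rebuilt"
```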
@@ -174,7 +174,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         super().__init__(**kwargs)
 
         if "add_bos_token" in kwargs or "add_eos_token" in kwargs:
-            self._update_bos_eos_tokens()
+            self.update_post_processor()
 
         # Set the splitting mode for special tokens for the tokenizer to be used throughout the class.
         self._tokenizer.encode_special_tokens = self.split_special_tokens
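Reading of this hunk: the constructor only rebuilds the post processor when the caller explicitly passed `add_bos_token` or `add_eos_token`, since the serialized post processor is assumed to already match the defaults. A toy sketch of that guard (hypothetical class, a print in place of a real rebuild):

```python
# Toy sketch of the constructor guard; hypothetical class, not the transformers one.
class FastTokenizerSketch:
    def __init__(self, **kwargs):
        self._add_bos_token = kwargs.get("add_bos_token", True)
        self._add_eos_token = kwargs.get("add_eos_token", False)
        # Rebuild only if the caller overrode the flags; otherwise trust the
        # post-processor that was loaded from disk.
        if "add_bos_token" in kwargs or "add_eos_token" in kwargs:
            self.update_post_processor()

    def update_post_processor(self):
        print(f"rebuilt with bos={self._add_bos_token}, eos={self._add_eos_token}")


FastTokenizerSketch()                    # nothing printed, defaults kept
FastTokenizerSketch(add_eos_token=True)  # rebuilt with bos=True, eos=True
```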
@@ -878,12 +878,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
 
         return self.__class__(tokenizer_object=tokenizer, **kwargs)
 
     # Copied from LlamaTokenizerFast (update_post_processor, add_bos_token, add_eos_token
-    def _update_bos_eos_tokens(self):
+    def update_post_processor(self):
         """
         Updates the underlying post processor with the current `bos_token` and `eos_token`.
         """
-        if not isinstance(self._tokenizer.post_processor, processors.Sequence):
+        if not isinstance(self._tokenizer.post_processor, processors.TemplateProcessing) and not isinstance(self._tokenizer.post_processor, processors.Sequence):
            return
 
         bos = self.bos_token
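The renamed method is, per the comment, copied from `LlamaTokenizerFast`: it rebuilds the backend post processor from the current flags. A simplified sketch of that idea using the `tokenizers` `TemplateProcessing` API (single/pair templates only; not the exact transformers implementation, and the transformers version additionally guards against unset bos/eos tokens):

```python
# Simplified sketch: rebuild a post-processor from bos/eos flags with the `tokenizers` library.
from tokenizers import processors


def build_post_processor(bos, bos_id, eos, eos_id, add_bos, add_eos):
    single = f"{bos + ':0 ' if add_bos else ''}$A:0{' ' + eos + ':0' if add_eos else ''}"
    pair = f"{single}{' ' + bos + ':1' if add_bos else ''} $B:1{' ' + eos + ':1' if add_eos else ''}"
    special_tokens = []
    if add_bos:
        special_tokens.append((bos, bos_id))
    if add_eos:
        special_tokens.append((eos, eos_id))
    return processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)


# BOS off, EOS on: the single-sequence template becomes "$A:0 </s>:0"
post_processor = build_post_processor("<s>", 1, "</s>", 2, add_bos=False, add_eos=True)
```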
@@ -919,9 +918,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     @add_eos_token.setter
     def add_eos_token(self, value):
         self._add_eos_token = value
-        self._update_bos_eos_tokens()
+        self.update_post_processor()
 
     @add_bos_token.setter
     def add_bos_token(self, value):
         self._add_bos_token = value
-        self._update_bos_eos_tokens()
+        self.update_post_processor()
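With the setters routed through `update_post_processor`, flipping the flags on a live tokenizer takes effect on the next encode. A hedged usage sketch, assuming the same Llama checkpoint as the test below and that the properties behave as this diff implements them:

```python
# Usage sketch; assumes the checkpoint is reachable and that the add_bos_token /
# add_eos_token properties work as implemented in this file.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tok.add_eos_token = True   # setter triggers update_post_processor()
assert tok("hello")["input_ids"][-1] == tok.eos_token_id
tok.add_bos_token = False  # rebuilt again, now without a leading BOS
assert tok("hello")["input_ids"][0] != tok.bos_token_id
```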
@@ -825,9 +825,7 @@ class CommonSpmIntegrationTests(unittest.TestCase):
         tokens = self.tokenizer.tokenize("No <s> ▁He")
         self.assertEqual(tokens, ["▁No", "<s>", "▁He"])  # spaces are eaten by rstrip / lstrip
 
-    @unittest.skip("@require_read_token does not work? getting gated repo error")
-    @require_read_token
+    @unittest.skip
     def test_bos_eos_tokens(self):
         tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", add_bos_token=False, add_eos_token=True)
         assert tokenizer("hello")["input_ids"][0] != tokenizer.bos_token_id  # no bos token