cleaning after rebase

Ita Zaporozhets 2024-07-29 16:09:41 +02:00
parent e8bfe9051f
commit f80a9fd479
4 changed files with 8 additions and 15 deletions

View File

@@ -685,16 +685,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         )
         # ["This is something", "<special_token_1>", "else"]
         tokenized_text = []
-        for idx, token in enumerate(tokens):
+        for token in tokens:
             # Need to skip eventual empty (fully stripped) tokens
-            kwargs["add_bos_token"] = kwargs.get("add_bos_token", None)
-            kwargs["add_eos_token"] = kwargs.get("add_eos_token", None)
             if not token:
                 continue
-            if idx == 0 and kwargs.get("add_special_tokens", None) is True:
-                kwargs["add_bos_token"] = getattr(self, "add_bos_token", None)
-            elif idx == len(tokens) - 1 and kwargs.get("add_special_tokens", None) is True:
-                kwargs["add_eos_token"] = getattr(self, "add_eos_token", None)
             if token in no_split_token:
                 tokenized_text.append(token)
             else:
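
After this cleanup, tokenize() no longer threads add_bos_token/add_eos_token through kwargs; the loop only splits the text around special tokens. A sketch of the resulting loop (the else branch is truncated above; in the upstream method it falls through to self._tokenize):

tokenized_text = []
for token in tokens:
    # Need to skip eventual empty (fully stripped) tokens
    if not token:
        continue
    if token in no_split_token:
        # added special tokens pass through whole, never split
        tokenized_text.append(token)
    else:
        # everything else goes through the model-specific subword tokenizer
        tokenized_text.extend(self._tokenize(token))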

View File

@@ -985,6 +985,8 @@ class SpecialTokensMixin:
         # if we are adding tokens that were not part of the vocab, we ought to add them
         added_tokens = self.add_tokens(added_tokens, special_tokens=True)
+        if hasattr(self, "update_post_processor"):
+            self.update_post_processor()
         return added_tokens

     def add_tokens(
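
Only the fast tokenizer class defines update_post_processor, while add_special_tokens lives on the shared SpecialTokensMixin, hence the hasattr guard: slow tokenizers simply skip the rebuild. Illustrative usage, assuming a fast checkpoint (the model name is taken from the test file further down):

from transformers import AutoTokenizer

# On a fast tokenizer, registering a special token now also refreshes the
# post processor, so the new token is respected at encode time; a slow
# tokenizer takes the same path but skips the rebuild via hasattr.
tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tok.add_special_tokens({"additional_special_tokens": ["<special_token_1>"]})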

View File

@@ -174,7 +174,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         super().__init__(**kwargs)
         if "add_bos_token" in kwargs or "add_eos_token" in kwargs:
-            self._update_bos_eos_tokens()
+            self.update_post_processor()

         # Set the splitting mode for special tokens for the tokenizer to be used throughout the class.
         self._tokenizer.encode_special_tokens = self.split_special_tokens
@@ -878,12 +878,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         return self.__class__(tokenizer_object=tokenizer, **kwargs)

     # Copied from LlamaTokenizerFast (update_post_processor, add_bos_token, add_eos_token)
-    def _update_bos_eos_tokens(self):
+    def update_post_processor(self):
         """
         Updates the underlying post processor with the current `bos_token` and `eos_token`.
         """
-        if not isinstance(self._tokenizer.post_processor, processors.Sequence):
+        if not isinstance(self._tokenizer.post_processor, processors.TemplateProcessing) and not isinstance(self._tokenizer.post_processor, processors.Sequence):
             return
         bos = self.bos_token
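
The hunk is cut off after bos = self.bos_token. Given the "Copied from LlamaTokenizerFast" note above, the rest of the method body presumably mirrors that implementation, roughly:

bos, bos_token_id = self.bos_token, self.bos_token_id
if bos is None and self.add_bos_token:
    raise ValueError("add_bos_token = True but bos_token = None")

eos, eos_token_id = self.eos_token, self.eos_token_id
if eos is None and self.add_eos_token:
    raise ValueError("add_eos_token = True but eos_token = None")

# Template strings such as "<s>:0 $A:0 </s>:0"; the :0/:1 suffixes are the
# token_type_ids assigned to the first and second sequence.
single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"

special_tokens = []
if self.add_bos_token:
    special_tokens.append((bos, bos_token_id))
if self.add_eos_token:
    special_tokens.append((eos, eos_token_id))
self._tokenizer.post_processor = processors.TemplateProcessing(
    single=single, pair=pair, special_tokens=special_tokens
)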
@@ -919,9 +918,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     @add_eos_token.setter
     def add_eos_token(self, value):
         self._add_eos_token = value
-        self._update_bos_eos_tokens()
+        self.update_post_processor()

     @add_bos_token.setter
     def add_bos_token(self, value):
         self._add_bos_token = value
-        self._update_bos_eos_tokens()
+        self.update_post_processor()
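
Because both setters now call update_post_processor(), flipping add_bos_token or add_eos_token takes effect on the next encode without reloading the tokenizer. Illustrative usage (checkpoint as in the test below; it may require authentication to download):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tok.add_bos_token = False  # setter rebuilds the post processor
assert tok("hello")["input_ids"][0] != tok.bos_token_id
tok.add_bos_token = True
assert tok("hello")["input_ids"][0] == tok.bos_token_id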

View File

@@ -825,9 +825,7 @@ class CommonSpmIntegrationTests(unittest.TestCase):
         tokens = self.tokenizer.tokenize("No <s> ▁He")
         self.assertEqual(tokens, ["▁No", "<s>", "▁He"])  # spaces are eaten by rstrip / lstrip

     @unittest.skip("@require_read_token does not work? getting gated repo error")
-    @require_read_token
-    @unittest.skip
     def test_bos_eos_tokens(self):
         tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", add_bos_token=False, add_eos_token=True)
         assert tokenizer("hello")["input_ids"][0] != tokenizer.bos_token_id  # no bos token
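
The test body is truncated here; given add_eos_token=True in the constructor call, the natural counterpart assertion would be:

assert tokenizer("hello")["input_ids"][-1] == tokenizer.eos_token_id  # eos token appended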