skip if post_processor is not of type Sequence (TemplateProcessing is not supported)

Ita Zaporozhets 2024-06-20 17:29:28 +02:00
parent d77e5ea7ff
commit a76d2b74bb
2 changed files with 14 additions and 4 deletions

src/transformers/tokenization_utils.py

@@ -685,10 +685,16 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
                         )
         # ["This is something", "<special_token_1>", "else"]
         tokenized_text = []
-        for token in tokens:
+        for idx, token in enumerate(tokens):
             # Need to skip eventual empty (fully stripped) tokens
+            kwargs["add_bos_token"] = kwargs.get("add_bos_token", None)
+            kwargs["add_eos_token"] = kwargs.get("add_eos_token", None)
             if not token:
                 continue
+            if idx == 0 and kwargs.get("add_special_tokens", None) is True:
+                kwargs["add_bos_token"] = getattr(self, "add_bos_token", None)
+            elif idx == len(tokens) - 1 and kwargs.get("add_special_tokens", None) is True:
+                kwargs["add_eos_token"] = getattr(self, "add_eos_token", None)
             if token in no_split_token:
                 tokenized_text.append(token)
             else:
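
For illustration only (not part of this commit), a minimal sketch of the intent behind the idx-based kwargs above, written as a standalone helper with hypothetical names (tokenize_chunks, tokenize_one): once the text has been split around added/special tokens, a BOS token should only be considered while tokenizing the first chunk and an EOS token only while tokenizing the last chunk.

from typing import Callable, List

def tokenize_chunks(
    chunks: List[str],
    tokenize_one: Callable[..., List[str]],
    add_bos_token: bool,
    add_eos_token: bool,
) -> List[str]:
    # Hypothetical helper mirroring the loop above: only the chunk at index 0
    # may receive BOS and only the chunk at the last index may receive EOS;
    # empty (fully stripped) chunks are skipped.
    out: List[str] = []
    for idx, chunk in enumerate(chunks):
        if not chunk:
            continue
        out.extend(
            tokenize_one(
                chunk,
                add_bos_token=add_bos_token and idx == 0,
                add_eos_token=add_eos_token and idx == len(chunks) - 1,
            )
        )
    return out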
@@ -1025,10 +1031,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
 
     @overload
-    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
+    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str:
+        ...
 
     @overload
-    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ...
+    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]:
+        ...
 
     def convert_ids_to_tokens(
         self, ids: Union[int, List[int]], skip_special_tokens: bool = False
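
For illustration only (not part of this commit): the two @overload stubs above only move the `...` placeholder onto its own line; the runtime behaviour is unchanged, i.e. `convert_ids_to_tokens` maps a single id to a `str` and a list of ids to a `List[str]`. A small usage sketch, using `gpt2` as an example checkpoint and `use_fast=False` so the slow PreTrainedTokenizer path above is exercised:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)  # slow (Python) tokenizer

ids = tok.convert_tokens_to_ids(tok.tokenize("Hello world"))
print(tok.convert_ids_to_tokens(ids[0]))  # int        -> str, e.g. "Hello"
print(tok.convert_ids_to_tokens(ids))     # List[int]  -> List[str]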

src/transformers/tokenization_utils_fast.py

@@ -879,6 +879,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         """
         Updates the underlying post processor with the current `bos_token` and `eos_token`.
         """
+        if not isinstance(self._tokenizer.post_processor, processors.Sequence):
+            return
         bos = self.bos_token
         bos_token_id = self.bos_token_id
         if bos is None and self.add_bos_token:
@@ -918,4 +921,3 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def add_bos_token(self, value):
         self._add_bos_token = value
         self._update_bos_eos_tokens()
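
For illustration only (not part of this commit), a minimal sketch of what the new guard skips, using the same private `_tokenizer` backend attribute the diff relies on; `gpt2` is just an example checkpoint whose fast tokenizer does not ship a Sequence post-processor, so the early return is taken:

from tokenizers import processors
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
post = tok._tokenizer.post_processor  # backend `tokenizers` object, may also be None

if not isinstance(post, processors.Sequence):
    # e.g. a ByteLevel or TemplateProcessing post-processor, or none at all:
    # _update_bos_eos_tokens() now returns early instead of rewriting the template.
    print("skipping update, post processor is", type(post).__name__ if post is not None else None)

As the commit title states, non-Sequence post-processors such as TemplateProcessing are intentionally not supported here, so the method simply returns instead of attempting to rebuild them.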