Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
🚨 🚨 🚨 Fix Issue 15003: SentencePiece Tokenizers Not Adding Special Tokens in convert_tokens_to_string (#15775)
* Add test for SentencePiece not adding special tokens to strings
* Add SentencePieceStringConversionMixin to fix issue 15003
* Fix conversion from tokens to string for most SentencePiece tokenizers
  Tokenizers fixed:
  - AlbertTokenizer
  - BarthezTokenizer
  - CamembertTokenizer
  - FNetTokenizer
  - M2M100Tokenizer
  - MBart50Tokenizer
  - PegasusTokenizer
  - Speech2TextTokenizer
* Fix MarianTokenizer, adjust SentencePiece test to accommodate vocab
* Fix DebertaV2Tokenizer
* Ignore LayoutXLMTokenizer in SentencePiece string conversion test
* Run 'make style' and 'make quality'
* Clean convert_tokens_to_string test
  Instead of explicitly ignoring LayoutXLMTokenizer in the test, override the test in LayoutLMTokenizationTest and do nothing in it.
* Remove commented out code
* Improve robustness of convert_tokens_to_string test
  Instead of comparing lengths of re-tokenized text and input_ids, check that converting all special tokens to string yields a string with all special tokens.
* Inline and remove SentencePieceStringConversionMixin
  The convert_tokens_to_string method is now implemented in each relevant SentencePiece tokenizer.
* Run 'make style' and 'make quality'
* Revert removal of space in convert_tokens_to_string
* Remove redundant import
* Revert test text to original
* Uncomment the lowercasing of the reverse_text variable
* Mimic Rust tokenizer behavior for tokenizers
  - Albert
  - Barthez
  - Camembert
  - MBart50
  - T5
* Fix accidentally skipping test in wrong tokenizer
* Add test for equivalent Rust and slow tokenizer behavior
* Override _decode in BigBirdTokenizer to mimic Rust behavior
* Override _decode in FNetTokenizer to mimic Rust behavior
* Override _decode in XLNetTokenizer to mimic Rust behavior
* Remove unused 're' import
* Update DebertaV2Tokenizer to mimic Rust tokenizer
* Deberta tokenizer now behaves like Albert and its `convert_tokens_to_string` is not tested.
* Ignore problematic tests in Deberta V2
* Add comment on why the Deberta V2 tests are skipped
parent fb7cbe236b
commit 9f9ddcc2de
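
Note: all of the tokenizer changes below follow the same pattern. Instead of handing the full token list to the SentencePiece model, special tokens are spliced into the output string and only the surrounding pieces are decoded. A minimal, standalone sketch of the user-visible effect (not part of the diff; the checkpoint name is just an example):

from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
tokens = tokenizer.tokenize("This is text to test the tokenizer.")
# Wrap the tokens with special tokens, as encode() would.
tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]

# Before this fix the slow tokenizer silently dropped "[CLS]" and "[SEP]" here,
# because the whole list went through sp_model.decode(); after the fix they are kept.
print(tokenizer.convert_tokens_to_string(tokens))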
@@ -250,7 +250,23 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return self.sp_model.IdToPiece(index)
 
     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -263,6 +263,25 @@ class BarthezTokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index)
 
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -278,10 +297,6 @@ class BarthezTokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
@@ -151,8 +151,17 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
@@ -16,6 +16,7 @@
 
 
 import os
+import re
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
 
@@ -182,8 +183,65 @@ class BigBirdTokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space before [MASK] and [SEP]
+        if spaces_between_special_tokens:
+            text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
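
A hedged usage sketch of the decoding behavior the `_decode` override above targets (checkpoint name is only an example):

from transformers import BigBirdTokenizer

tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
ids = tokenizer("Hello world").input_ids

# With the override, the regex strips the space before [SEP]/[MASK], so the slow
# tokenizer's decode() output should line up with the fast (Rust) tokenizer's.
print(tokenizer.decode(ids))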
@@ -261,6 +261,25 @@ class CamembertTokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
 
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -276,10 +295,6 @@ class CamembertTokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
@@ -146,7 +146,9 @@ class DebertaV2Tokenizer(PreTrainedTokenizer):
         self.do_lower_case = do_lower_case
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
-        self._tokenizer = SPMTokenizer(vocab_file, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs)
+        self._tokenizer = SPMTokenizer(
+            vocab_file, self.all_special_tokens, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs
+        )
 
     @property
     def vocab_size(self):
@@ -291,7 +293,9 @@ class SPMTokenizer:
     BPE-dropout.
     """
 
-    def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None
+    ):
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
@@ -312,6 +316,7 @@ class SPMTokenizer:
         # self.vocab['[UNK]'] = 3
 
         self.spm = spm
+        self.special_tokens = special_tokens
 
     def __getstate__(self):
        state = self.__dict__.copy()
@@ -339,7 +344,22 @@ class SPMTokenizer:
 
     def decode(self, tokens, start=-1, end=-1, raw_text=None):
         if raw_text is None:
-            return self.spm.decode_pieces([t for t in tokens])
+            current_sub_tokens = []
+            out_string = ""
+            prev_is_special = False
+            for token in tokens:
+                # make sure that special tokens are not decoded using sentencepiece model
+                if token in self.special_tokens:
+                    if not prev_is_special:
+                        out_string += " "
+                    out_string += self.spm.decode_pieces(current_sub_tokens) + token
+                    prev_is_special = True
+                    current_sub_tokens = []
+                else:
+                    current_sub_tokens.append(token)
+                    prev_is_special = False
+            out_string += self.spm.decode_pieces(current_sub_tokens)
+            return out_string.strip()
         else:
             words = self.split_to_words(raw_text)
             word_tokens = [self.tokenize(w) for w in words]
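
Note that SPMTokenizer now requires the special tokens to be passed in explicitly. A hedged sketch of direct use of this helper class (the model file path is a placeholder):

from transformers.models.deberta_v2.tokenization_deberta_v2 import SPMTokenizer

spm_tokenizer = SPMTokenizer(
    "spm.model",  # path to a SentencePiece model file (placeholder)
    special_tokens=["[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"],
    split_by_punct=False,
)
# decode() now splices these special tokens into the output instead of passing
# them through the SentencePiece model.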
@@ -15,6 +15,7 @@
 """ Tokenization classes for FNet model."""
 
 import os
+import re
 import unicodedata
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -213,7 +214,66 @@ class FNetTokenizer(PreTrainedTokenizer):
         return self.sp_model.IdToPiece(index)
 
     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space after <unk>
+        if spaces_between_special_tokens:
+            text = re.sub(r"(<unk>) ", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -218,9 +218,19 @@ class M2M100Tokenizer(PreTrainedTokenizer):
             return self.id_to_lang_token[index]
         return self.decoder.get(index, self.unk_token)
 
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
@@ -265,10 +265,18 @@ class MarianTokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
-        if self._decode_use_source_tokenizer:
-            return self.spm_source.DecodePieces(tokens)
-        else:
-            return self.spm_target.DecodePieces(tokens)
+        sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += sp_model.decode_pieces(current_sub_tokens)
+        return out_string.strip()
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
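
A short sketch of the source/target selection this method relies on; the checkpoint name is only an example, and `use_source_tokenizer` is the existing decode flag that sets `_decode_use_source_tokenizer`:

from transformers import MarianTokenizer

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
ids = tokenizer("Hello world").input_ids

print(tokenizer.decode(ids))                             # decodes with the target spm (default)
print(tokenizer.decode(ids, use_source_tokenizer=True))  # decodes with the source spm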
@@ -232,9 +232,24 @@ class MBart50Tokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
 
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
@@ -231,8 +231,17 @@ class PegasusTokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def num_special_tokens_to_add(self, pair=False):
         """Just EOS"""
@@ -158,8 +158,17 @@ class ReformerTokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
@@ -190,11 +190,19 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        out_string = self.sp_model.decode(tokens)
-
-        if self.do_upper_case:
-            out_string = out_string.upper()
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                decoded = self.sp_model.decode(current_sub_tokens)
+                out_string += (decoded.upper() if self.do_upper_case else decoded) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        decoded = self.sp_model.decode(current_sub_tokens)
+        out_string += decoded.upper() if self.do_upper_case else decoded
+        return out_string.strip()
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
@@ -311,14 +311,19 @@ class T5Tokenizer(PreTrainedTokenizer):
         """Converts a sequence of tokens (string) in a single string."""
         current_sub_tokens = []
         out_string = ""
+        prev_is_special = False
         for token in tokens:
             # make sure that special tokens are not decoded using sentencepiece model
             if token in self.all_special_tokens:
-                out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " "
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
                 current_sub_tokens = []
             else:
                 current_sub_tokens.append(token)
-        out_string += self.sp_model.decode_pieces(current_sub_tokens)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
         return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
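
A hedged sketch of the spacing behavior this hunk aligns with T5TokenizerFast (checkpoint name is only an example; <extra_id_0> is one of T5's sentinel special tokens):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokens = tokenizer.tokenize("Translate: <extra_id_0> world")

# Sentinel/special tokens are kept verbatim, with a single leading space and no
# trailing space, matching the fast tokenizer's output.
print(tokenizer.convert_tokens_to_string(tokens))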
@@ -250,6 +250,46 @@ class XLNetTokenizer(PreTrainedTokenizer):
         out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
         return out_string
 
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # By default, there are no spaces between special tokens
+        text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
@@ -37,7 +37,7 @@ class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         super().setUp()
 
         # We have a SentencePiece fixture for testing
-        tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB)
+        tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB, unk_token="<unk>")
        tokenizer.save_pretrained(self.tmpdirname)
 
     def get_input_output_texts(self, tokenizer):
@@ -55,7 +55,6 @@ class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
     def test_get_vocab(self):
         vocab_keys = list(self.get_tokenizer().get_vocab().keys())
-
         self.assertEqual(vocab_keys[0], "<pad>")
         self.assertEqual(vocab_keys[1], "<unk>")
         self.assertEqual(vocab_keys[-1], "[PAD]")
@@ -80,6 +79,14 @@ class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
         self.assertListEqual(rust_tokens, tokens_target)
 
+    @unittest.skip("There is an inconsistency between slow and fast tokenizer due to a bug in the fast one.")
+    def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
+        pass
+
+    @unittest.skip("There is an inconsistency between slow and fast tokenizer due to a bug in the fast one.")
+    def test_sentencepiece_tokenize_and_decode(self):
+        pass
+
     def test_split_by_punct(self):
         # fmt: off
         sequence = "I was born in 92000, and this is falsé."
@@ -1946,3 +1946,11 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @unittest.skip("Doesn't support another framework than PyTorch")
     def test_np_encode_plus_sent_to_model(self):
         pass
+
+    @unittest.skip("Doesn't use SentencePiece")
+    def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
+        pass
+
+    @unittest.skip("Doesn't use SentencePiece")
+    def test_sentencepiece_tokenize_and_decode(self):
+        pass
@@ -385,6 +385,33 @@ class TokenizerTesterMixin:
 
         self.assertEqual(reverse_text, text)
 
+        special_tokens = tokenizer.all_special_tokens
+        special_tokens_string = tokenizer.convert_tokens_to_string(special_tokens)
+        for special_token in special_tokens:
+            self.assertIn(special_token, special_tokens_string)
+
+        if self.test_rust_tokenizer:
+            rust_tokenizer = self.get_rust_tokenizer()
+            special_tokens_string_rust = rust_tokenizer.convert_tokens_to_string(special_tokens)
+            self.assertEqual(special_tokens_string, special_tokens_string_rust)
+
+    def test_sentencepiece_tokenize_and_decode(self):
+        if not self.test_sentencepiece:
+            return
+
+        text = "This is text to test the tokenizer."
+        if self.test_rust_tokenizer:
+            tokenizer = self.get_tokenizer()
+            rust_tokenizer = self.get_rust_tokenizer()
+
+            slow_ids = tokenizer(text).input_ids
+            fast_ids = rust_tokenizer(text).input_ids
+            self.assertEqual(slow_ids, fast_ids)
+
+            slow_decoded = tokenizer.decode(slow_ids)
+            fast_decoded = rust_tokenizer.decode(slow_ids)
+            self.assertEqual(slow_decoded, fast_decoded)
+
     def test_subword_regularization_tokenizer(self) -> None:
         if not self.test_sentencepiece:
             return
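
Outside the test harness, the slow/fast equivalence asserted by the new test can be reproduced with a sketch like this (checkpoint name is only an example):

from transformers import AlbertTokenizer, AlbertTokenizerFast

slow = AlbertTokenizer.from_pretrained("albert-base-v2")
fast = AlbertTokenizerFast.from_pretrained("albert-base-v2")

text = "This is text to test the tokenizer."
slow_ids = slow(text).input_ids
fast_ids = fast(text).input_ids
assert slow_ids == fast_ids

# After this PR, decoding the same ids should give identical strings.
assert slow.decode(slow_ids) == fast.decode(slow_ids)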