[tokenizers] Fix #5081 and improve backward compatibility (#5125)

* fix #5081 and improve backward compatibility (slightly)

* add nlp to setup.cfg - style and quality

* align default to previous default

* remove test that doesn't generalize
Thomas Wolf 2020-06-22 17:25:43 +02:00 committed by GitHub
parent d2a7c86dc3
commit ebc36108dc
3 changed files with 45 additions and 33 deletions
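For reference, the incompatibility behind #5081 is that fast tokenizers cannot compute ``get_special_tokens_mask`` from raw ids with ``already_has_special_tokens=False`` (the new base-class assert below spells this out); the portable route for any tokenizer is to request the mask at encoding time. A minimal sketch of that usage (``bert-base-uncased`` is only an example checkpoint, not part of this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

# Ask for the special tokens mask while encoding, instead of recomputing it
# afterwards; this works for both slow and fast tokenizers.
encoded = tokenizer.encode_plus(
    "Encode this.", add_special_tokens=True, return_special_tokens_mask=True
)
print(encoded["special_tokens_mask"])  # e.g. [1, 0, ..., 0, 1]: 1 marks [CLS]/[SEP]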


@@ -755,16 +755,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
def decode(
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
) -> str:
"""
Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
Args:
token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
skip_special_tokens: if set to True, will replace special tokens.
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
# To avoid mixing byte-level and unicode for byte-level BPE
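The docstring removed here is not lost; an identical contract is added on ``PreTrainedTokenizerBase`` in the next hunk, so it is documented once for slow and fast tokenizers. A quick round-trip sketch, reusing the example ``tokenizer`` from the snippet above:

ids = tokenizer.encode("Hello world!", add_special_tokens=True)
text = tokenizer.decode(ids, skip_special_tokens=True)
print(text)  # "hello world!": uncased BERT lower-cases, and cleanup removes the space before "!"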


@@ -1774,6 +1774,51 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
return [self.decode(seq, **kwargs) for seq in sequences]
def decode(
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
) -> str:
"""
Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
Args:
token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
skip_special_tokens: if set to True, will replace special tokens.
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
"""
raise NotImplementedError
def get_special_tokens_mask(
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formatted with
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
assert already_has_special_tokens and token_ids_1 is None, (
"You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
"Please use a slow (full python) tokenizer to activate this argument. "
"Or set `return_special_tokens_mask=True` when calling the encoding method "
"to get the special tokens mask in any tokenizer."
)
all_special_ids = self.all_special_ids # cache the property
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
return special_tokens_mask
@staticmethod
def clean_up_tokenization(out_string: str) -> str:
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.


@@ -672,29 +672,6 @@ class TokenizerTesterMixin:
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
def test_special_tokens_mask_already_has_special_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
if not hasattr(tokenizer, "get_special_tokens_mask") or tokenizer.get_special_tokens_mask(
[0, 1, 2, 3]
) == [0, 0, 0, 0]:
continue
with self.subTest(f"{tokenizer.__class__.__name__}"):
sequence_0 = "Encode this."
if (
tokenizer.cls_token_id == tokenizer.unk_token_id
and tokenizer.sep_token_id == tokenizer.unk_token_id
):
tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
encoded_sequence_dict = tokenizer.encode_plus(
sequence_0, add_special_tokens=True, return_special_tokens_mask=True
)
# encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
min_val = min(special_tokens_mask_orig)
max_val = max(special_tokens_mask_orig)
self.assertNotEqual(min_val, max_val)
def test_right_and_left_padding(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
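The deleted test assumed every tokenizer can produce a non-trivial mask from raw ids, which the new base-class assert explicitly rules out for fast tokenizers. A sketch of the check that does generalize, staying on the public encoding API (same example ``tokenizer`` as above):

enc = tokenizer.encode_plus(
    "Encode this.", add_special_tokens=True, return_special_tokens_mask=True
)
assert len(enc["input_ids"]) == len(enc["special_tokens_mask"])
assert set(enc["special_tokens_mask"]) == {0, 1}  # both special and sequence tokens appear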