Correct order of overflowing_tokens for slow tokenizer (#13179)

* correct order of overflowing_tokens for slow tokenizer (fixes issue #13148)

* python 3.9 requires sentencepiece version 0.1.94 or above

* slicing of ids fixed in truncated_sequence()

* Update setup.py

* Correct order of overflowing tokens for pair of sentences

* code reformatted

* Update tokenization_utils_base.py

* reformatting file

* test to check single_input added

* missing function restored

* test to check pair_input overflowing tokens order

* test to check pair_input overflowing tokens order

* test to check pair_input overflowing tokens order

* added an error message for pair of seq and longest_first strategy

* test for pair_input modified

* variable name corrected

* fixed a typo in error message

* requested changes implemented

* required test added

* Corrected the message to match test message

* added error message for Luke Tokenizer

* lost test recovered

* docstring for truncate_sequences and prepare_for_model updated

* docstring for luke tokenizer updated

* updated ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING

* aligned text and fixed punctuation

* improved style and quality of code

* fixed error_msg in truncate_sequences

* replaced encode_plus method with regular call method

* clean up

* rephrased the docstring
Apoorv Garg 2021-09-02 15:28:23 +05:30 committed by GitHub
parent c9184a2e03
commit b91e65afe0
3 changed files with 119 additions and 65 deletions
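
To illustrate the user-facing effect of these diffs, here is a minimal usage sketch (an editor's illustration, not part of the commit; the `bert-base-uncased` checkpoint and the example sentences are arbitrary choices): with a slow (Python) tokenizer, asking for overflowing tokens on a sequence pair under the `longest_first` strategy now raises a `ValueError`, while a single-sequence strategy such as `only_first` still returns the overflow, now in the original token order.

```python
from transformers import BertTokenizer

# Illustrative slow (Python) tokenizer; any slow tokenizer behaves the same way.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

seq_a = " ".join(["hello"] * 30)  # long first sequence, guaranteed to overflow
seq_b = "a short second sequence"

# After this commit: pair input + longest_first + return_overflowing_tokens raises.
try:
    tokenizer(
        seq_a,
        seq_b,
        max_length=16,
        truncation="longest_first",
        return_overflowing_tokens=True,
    )
except ValueError as err:
    print(err)  # "Not possible to return overflowing tokens for pair of sequences with the `longest_first`. ..."

# A single-sequence strategy still returns the overflow, taken as one contiguous,
# correctly ordered slice of the first sequence.
encoded = tokenizer(
    seq_a,
    seq_b,
    max_length=16,
    truncation="only_first",
    return_overflowing_tokens=True,
    stride=2,
)
print(encoded["overflowing_tokens"])
```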


@@ -85,7 +85,9 @@ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
             `What are attention masks? <../glossary.html#attention-mask>`__
         return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to return overflowing token sequences.
+            Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+            of pairs) is provided with :obj:`truncation_strategy = longest_first` or :obj:`True`, an error is
+            raised instead of returning overflowing tokens.
         return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to return special tokens mask information.
         return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -1037,8 +1039,9 @@ class LukeTokenizer(RobertaTokenizer):
         Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids,
         entity spans so that it can be used by the model. It adds special tokens, truncates sequences if overflowing
         while taking into account the special tokens and manages a moving window (with user defined stride) for
-        overflowing tokens
+        overflowing tokens. Please Note, for `pair_ids` different than `None` and `truncation_strategy = longest_first`
+        or `True`, it is not possible to return overflowing tokens. Such a combination of arguments will raise an
+        error.

         Args:
             ids (:obj:`List[int]`):
@@ -1078,6 +1081,16 @@ class LukeTokenizer(RobertaTokenizer):
                 "results in an undefined behavior. Please set add_special_tokens to True or "
                 "set return_token_type_ids to None."
             )
+        if (
+            return_overflowing_tokens
+            and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+            and pair_ids is not None
+        ):
+            raise ValueError(
+                "Not possible to return overflowing tokens for pair of sequences with the "
+                "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                "for instance `only_second` or `only_first`."
+            )

         # Load from model defaults
         if return_token_type_ids is None:


@@ -1324,7 +1324,9 @@ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
             `What are attention masks? <../glossary.html#attention-mask>`__
         return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to return overflowing token sequences.
+            Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+            of pairs) is provided with :obj:`truncation_strategy = longest_first` or :obj:`True`, an error is
+            raised instead of returning overflowing tokens.
         return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to return special tokens mask information.
         return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -2838,7 +2840,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         """
         Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
         adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
-        manages a moving window (with user defined stride) for overflowing tokens
+        manages a moving window (with user defined stride) for overflowing tokens. Please Note, for `pair_ids`
+        different than `None` and `truncation_strategy = longest_first` or `True`, it is not possible to return
+        overflowing tokens. Such a combination of arguments will raise an error.

         Args:
             ids (:obj:`List[int]`):
@@ -2870,6 +2874,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 "set return_token_type_ids to None."
             )
+        if (
+            return_overflowing_tokens
+            and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+            and pair_ids is not None
+        ):
+            raise ValueError(
+                "Not possible to return overflowing tokens for pair of sequences with the "
+                "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                "for instance `only_second` or `only_first`."
+            )
+
         # Load from model defaults
         if return_token_type_ids is None:
             return_token_type_ids = "token_type_ids" in self.model_input_names
@@ -2977,7 +2992,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         Returns:
             :obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the
-            list of overflowing tokens.
+            list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing_tokens if
+            a pair of sequences (or a batch of pairs) is provided.
         """
         if num_tokens_to_remove <= 0:
             return ids, pair_ids, []
@@ -2986,34 +3002,36 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             truncation_strategy = TruncationStrategy(truncation_strategy)

         overflowing_tokens = []
-        if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
-            for _ in range(num_tokens_to_remove):
-                if pair_ids is None or len(ids) > len(pair_ids):
-                    if not overflowing_tokens:
-                        window_len = min(len(ids), stride + 1)
-                    else:
-                        window_len = 1
-                    overflowing_tokens.extend(ids[-window_len:])
-                    ids = ids[:-1]
-                else:
-                    if not overflowing_tokens:
-                        window_len = min(len(pair_ids), stride + 1)
-                    else:
-                        window_len = 1
-                    overflowing_tokens.extend(pair_ids[-window_len:])
-                    pair_ids = pair_ids[:-1]
-        elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
+        if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+            truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+        ):
             if len(ids) > num_tokens_to_remove:
                 window_len = min(len(ids), stride + num_tokens_to_remove)
                 overflowing_tokens = ids[-window_len:]
                 ids = ids[:-num_tokens_to_remove]
             else:
-                logger.error(
-                    f"We need to remove {num_tokens_to_remove} to truncate the input"
+                error_msg = (
+                    f"We need to remove {num_tokens_to_remove} to truncate the input "
                     f"but the first sequence has a length {len(ids)}. "
-                    f"Please select another truncation strategy than {truncation_strategy}, "
-                    f"for instance 'longest_first' or 'only_second'."
                 )
+                if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+                    error_msg = (
+                        error_msg + "Please select another truncation strategy than "
+                        f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+                    )
+                logger.error(error_msg)
+        elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+            logger.warning(
+                f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
+                f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+                f"truncation strategy. So the returned list will always be empty even if some "
+                f"tokens have been removed."
+            )
+            for _ in range(num_tokens_to_remove):
+                if pair_ids is None or len(ids) > len(pair_ids):
+                    ids = ids[:-1]
+                else:
+                    pair_ids = pair_ids[:-1]
         elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
             if len(pair_ids) > num_tokens_to_remove:
                 window_len = min(len(pair_ids), stride + num_tokens_to_remove)
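
The hunk above is the core of the fix for #13148: for `only_first` (and `longest_first` on a single sequence) the overflow is now one contiguous slice of the input, so the tokens come back in their original order, while sequence pairs under `longest_first` only log a warning and return an empty list. A standalone sketch of that slicing (an editor's illustration, not the library code) makes the order guarantee explicit:

```python
def truncate_only_first(ids, num_tokens_to_remove, stride=0):
    # The overflow is the removed tokens plus `stride` tokens of context,
    # taken as one slice from the end, so the original order is preserved.
    window_len = min(len(ids), stride + num_tokens_to_remove)
    overflowing_tokens = ids[-window_len:]
    ids = ids[:-num_tokens_to_remove]
    return ids, overflowing_tokens


ids = [101, 102, 103, 104, 105, 106]
kept, overflow = truncate_only_first(ids, num_tokens_to_remove=2, stride=1)
assert kept == [101, 102, 103, 104]
assert overflow == [104, 105, 106]  # original order, with one stride token of context
```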


@@ -941,6 +941,7 @@ class TokenizerTesterMixin:
                     self.assertEqual(truncated_sequence, sequence[:-2])
                     self.assertEqual(len(overflowing_tokens), 2 + stride)
+                    self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :])

     def test_maximum_encoding_length_pair_input(self):
         tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
@@ -1053,18 +1054,18 @@ class TokenizerTesterMixin:
                     overflow_first_sequence if len(seq0_tokens) > len(seq1_tokens) else overflow_second_sequence
                 )

-                information = tokenizer.encode_plus(
-                    seq_0,
-                    seq_1,
-                    max_length=len(sequence) - 2,
-                    add_special_tokens=False,
-                    stride=stride,
-                    truncation="longest_first",
-                    return_overflowing_tokens=True,
-                    # add_prefix_space=False,
-                )
-
                 # Overflowing tokens are handled quite differently in slow and fast tokenizers
                 if isinstance(tokenizer, PreTrainedTokenizerFast):
+                    information = tokenizer(
+                        seq_0,
+                        seq_1,
+                        max_length=len(sequence) - 2,
+                        add_special_tokens=False,
+                        stride=stride,
+                        truncation="longest_first",
+                        return_overflowing_tokens=True,
+                        # add_prefix_space=False,
+                    )
                     truncated_sequence = information["input_ids"][0]
                     overflowing_tokens = information["input_ids"][1]
                     self.assertEqual(len(information["input_ids"]), 2)
@@ -1075,28 +1076,39 @@ class TokenizerTesterMixin:
                     self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
                     self.assertEqual(overflowing_tokens, overflow_longest_sequence)
                 else:
-                    truncated_sequence = information["input_ids"]
-                    overflowing_tokens = information["overflowing_tokens"]
-
-                    self.assertEqual(len(truncated_sequence), len(sequence) - 2)
-                    self.assertEqual(truncated_sequence, truncated_longest_sequence)
-
-                    self.assertEqual(
-                        len(overflowing_tokens), 2 + stride
-                    )  # No overflowing tokens when using 'longest' in python tokenizers
-
-                information = tokenizer.encode_plus(
-                    seq_0,
-                    seq_1,
-                    max_length=len(sequence) - 2,
-                    add_special_tokens=False,
-                    stride=stride,
-                    truncation=True,
-                    return_overflowing_tokens=True,
-                    # add_prefix_space=False,
-                )
+                    # No overflowing tokens when using 'longest' in python tokenizers
+                    with self.assertRaises(ValueError) as context:
+                        information = tokenizer(
+                            seq_0,
+                            seq_1,
+                            max_length=len(sequence) - 2,
+                            add_special_tokens=False,
+                            stride=stride,
+                            truncation="longest_first",
+                            return_overflowing_tokens=True,
+                            # add_prefix_space=False,
+                        )
+
+                    self.assertTrue(
+                        context.exception.args[0].startswith(
+                            "Not possible to return overflowing tokens for pair of sequences with the "
+                            "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                            "for instance `only_second` or `only_first`."
+                        )
+                    )

                 # Overflowing tokens are handled quite differently in slow and fast tokenizers
                 if isinstance(tokenizer, PreTrainedTokenizerFast):
+                    information = tokenizer(
+                        seq_0,
+                        seq_1,
+                        max_length=len(sequence) - 2,
+                        add_special_tokens=False,
+                        stride=stride,
+                        truncation=True,
+                        return_overflowing_tokens=True,
+                        # add_prefix_space=False,
+                    )
                     truncated_sequence = information["input_ids"][0]
                     overflowing_tokens = information["input_ids"][1]
                     self.assertEqual(len(information["input_ids"]), 2)
@@ -1107,17 +1119,28 @@ class TokenizerTesterMixin:
                     self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
                     self.assertEqual(overflowing_tokens, overflow_longest_sequence)
                 else:
-                    truncated_sequence = information["input_ids"]
-                    overflowing_tokens = information["overflowing_tokens"]
-
-                    self.assertEqual(len(truncated_sequence), len(sequence) - 2)
-                    self.assertEqual(truncated_sequence, truncated_longest_sequence)
-
-                    self.assertEqual(
-                        len(overflowing_tokens), 2 + stride
-                    )  # No overflowing tokens when using 'longest' in python tokenizers
-
-                information_first_truncated = tokenizer.encode_plus(
+                    # No overflowing tokens when using 'longest' in python tokenizers
+                    with self.assertRaises(ValueError) as context:
+                        information = tokenizer(
+                            seq_0,
+                            seq_1,
+                            max_length=len(sequence) - 2,
+                            add_special_tokens=False,
+                            stride=stride,
+                            truncation=True,
+                            return_overflowing_tokens=True,
+                            # add_prefix_space=False,
+                        )
+
+                    self.assertTrue(
+                        context.exception.args[0].startswith(
+                            "Not possible to return overflowing tokens for pair of sequences with the "
+                            "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                            "for instance `only_second` or `only_first`."
+                        )
+                    )
+
+                information_first_truncated = tokenizer(
                     seq_0,
                     seq_1,
                     max_length=len(sequence) - 2,
@@ -1148,7 +1171,7 @@ class TokenizerTesterMixin:
                     self.assertEqual(len(overflowing_tokens), 2 + stride)
                     self.assertEqual(overflowing_tokens, seq0_tokens[-(2 + stride) :])

-                information_second_truncated = tokenizer.encode_plus(
+                information_second_truncated = tokenizer(
                     seq_0,
                     seq_1,
                     max_length=len(sequence) - 2,