mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Warning about too long input for fast tokenizers too (#8799)
* Warning about too long input for fast tokenizers too If truncation is not set in tokenizers, but the tokenization is too long for the model (`model_max_length`), we used to trigger a warning that The input would probably fail (which it most likely will). This PR re-enables the warning for fast tokenizers too and uses common code for the trigger to make sure it's consistent across. * Checking for pair of inputs too. * Making the function private and adding it's doc. * Remove formatting ?? in odd place. * Missed uppercase.
This commit is contained in:
parent
f6b44e6190
commit
a8c3f9aa76
@ -2866,14 +2866,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
|
||||
|
||||
# Check lengths
|
||||
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
|
||||
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(encoded_inputs["input_ids"]), self.model_max_length)
|
||||
)
|
||||
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
||||
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
|
||||
|
||||
# Padding
|
||||
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
||||
@ -3204,3 +3197,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
.replace(" 're", "'re")
|
||||
)
|
||||
return out_string
|
||||
|
||||
def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
|
||||
"""
|
||||
Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's
|
||||
corresponding model
|
||||
|
||||
Args:
|
||||
ids (:obj:`List[str]`): The ids produced by the tokenization
|
||||
max_length (:obj:`int`, `optional`): The max_length desired (does not trigger a warning if it is set)
|
||||
verbose (:obj:`bool`): Whether or not to print more information and warnings.
|
||||
|
||||
"""
|
||||
if max_length is None and len(ids) > self.model_max_length and verbose:
|
||||
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(ids), self.model_max_length)
|
||||
)
|
||||
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
||||
|
@ -418,6 +418,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
overflow_to_sample_mapping += [i] * len(toks["input_ids"])
|
||||
sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
|
||||
|
||||
for input_ids in sanitized_tokens["input_ids"]:
|
||||
self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
|
||||
return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
|
||||
|
||||
def _encode_plus(
|
||||
@ -474,6 +476,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
batched_output.encodings,
|
||||
)
|
||||
|
||||
self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
|
||||
|
||||
return batched_output
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
|
@ -666,11 +666,28 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(len(output["input_ids"][0]), model_max_length)
|
||||
|
||||
# Simple with no truncation
|
||||
output = tokenizer(seq_1, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
# Reset warnings
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer(seq_1, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
output = tokenizer([seq_1], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer([seq_1], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
# Overflowing tokens
|
||||
stride = 2
|
||||
@ -770,11 +787,28 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(len(output["input_ids"][0]), model_max_length)
|
||||
|
||||
# Simple with no truncation
|
||||
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
# Reset warnings
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
|
||||
seq_1, add_special_tokens=False
|
||||
|
Loading…
Reference in New Issue
Block a user