mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
add warning to let the user know that the __call__
method is faster than encode
+ pad
for a fast tokenizer (#18693)
* add warning to let the user know that the method is slower that for a fast tokenizer * user warnings * fix layoutlmv2 * fix layout* * change warnings into logger.warning
This commit is contained in:
parent
dcff504e18
commit
6667b0d7bf
@ -2821,7 +2821,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
in the batch.
|
||||
|
||||
Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`,
|
||||
`self.pad_token_id` and `self.pad_token_type_id`)
|
||||
`self.pad_token_id` and `self.pad_token_type_id`).
|
||||
|
||||
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the
|
||||
text followed by a call to the `pad` method to get a padded encoding.
|
||||
|
||||
<Tip>
|
||||
|
||||
@ -2871,6 +2874,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
verbose (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to print more information and warnings.
|
||||
"""
|
||||
if self.__class__.__name__.endswith("Fast"):
|
||||
if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False):
|
||||
logger.warning_advice(
|
||||
f"You're using a {self.__class__.__name__} tokenizer. Please note that with a fast tokenizer,"
|
||||
" using the `__call__` method is faster than using a method to encode the text followed by a call"
|
||||
" to the `pad` method to get a padded encoding."
|
||||
)
|
||||
self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
||||
|
||||
# If we have a list of dicts, let's convert it in a dict of lists
|
||||
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
||||
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
|
||||
|
@ -21,7 +21,14 @@ import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from transformers import AddedToken, LayoutLMv2TokenizerFast, SpecialTokensMixin, is_tf_available, is_torch_available
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
LayoutLMv2TokenizerFast,
|
||||
SpecialTokensMixin,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.models.layoutlmv2.tokenization_layoutlmv2 import (
|
||||
VOCAB_FILES_NAMES,
|
||||
BasicTokenizer,
|
||||
@ -41,6 +48,9 @@ from ...test_tokenization_common import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
@require_pandas
|
||||
class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
@ -788,6 +798,49 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
def test_padding_warning_message_fast_tokenizer(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
words, boxes = self.get_words_and_boxes_batch()
|
||||
|
||||
tokenizer_fast = self.get_rust_tokenizer()
|
||||
|
||||
encoding_fast = tokenizer_fast(
|
||||
words,
|
||||
boxes=boxes,
|
||||
)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
tokenizer_fast.pad(encoding_fast)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to"
|
||||
" encode the text followed by a call to the `pad` method to get a padded encoding.",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
if not self.test_slow_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer_slow = self.get_tokenizer()
|
||||
|
||||
encoding_slow = tokenizer_slow(
|
||||
words,
|
||||
boxes=boxes,
|
||||
)
|
||||
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
# We want to assert there are no warnings, but the 'assertLogs' method does not support that.
|
||||
# Therefore, we are adding a dummy warning, and then we will assert it is the only warning.
|
||||
logger.warning("Dummy warning")
|
||||
tokenizer_slow.pad(encoding_slow)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Dummy warning",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
|
@ -22,13 +22,23 @@ import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from transformers import AddedToken, LayoutLMv3TokenizerFast, SpecialTokensMixin, is_tf_available, is_torch_available
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
LayoutLMv3TokenizerFast,
|
||||
SpecialTokensMixin,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES, LayoutLMv3Tokenizer
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow
|
||||
|
||||
from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin, merge_model_tokenizer_mappings
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
@require_pandas
|
||||
class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
@ -668,6 +678,49 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
def test_padding_warning_message_fast_tokenizer(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
words, boxes = self.get_words_and_boxes_batch()
|
||||
|
||||
tokenizer_fast = self.get_rust_tokenizer()
|
||||
|
||||
encoding_fast = tokenizer_fast(
|
||||
words,
|
||||
boxes=boxes,
|
||||
)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
tokenizer_fast.pad(encoding_fast)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to"
|
||||
" encode the text followed by a call to the `pad` method to get a padded encoding.",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
if not self.test_slow_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer_slow = self.get_tokenizer()
|
||||
|
||||
encoding_slow = tokenizer_slow(
|
||||
words,
|
||||
boxes=boxes,
|
||||
)
|
||||
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
# We want to assert there are no warnings, but the 'assertLogs' method does not support that.
|
||||
# Therefore, we are adding a dummy warning, and then we will assert it is the only warning.
|
||||
logger.warning("Dummy warning")
|
||||
tokenizer_slow.pad(encoding_slow)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Dummy warning",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
|
@ -19,7 +19,14 @@ import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from transformers import AddedToken, LayoutXLMTokenizerFast, SpecialTokensMixin, is_tf_available, is_torch_available
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
LayoutXLMTokenizerFast,
|
||||
SpecialTokensMixin,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.models.layoutxlm.tokenization_layoutxlm import LayoutXLMTokenizer
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
@ -40,6 +47,7 @@ from ...test_tokenization_common import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
@ -697,6 +705,49 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
def test_padding_warning_message_fast_tokenizer(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
words, boxes = self.get_words_and_boxes_batch()
|
||||
|
||||
tokenizer_fast = self.get_rust_tokenizer()
|
||||
|
||||
encoding_fast = tokenizer_fast(
|
||||
words,
|
||||
boxes=boxes,
|
||||
)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
tokenizer_fast.pad(encoding_fast)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to"
|
||||
" encode the text followed by a call to the `pad` method to get a padded encoding.",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
if not self.test_slow_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer_slow = self.get_tokenizer()
|
||||
|
||||
encoding_slow = tokenizer_slow(
|
||||
words,
|
||||
boxes=boxes,
|
||||
)
|
||||
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
# We want to assert there are no warnings, but the 'assertLogs' method does not support that.
|
||||
# Therefore, we are adding a dummy warning, and then we will assert it is the only warning.
|
||||
logger.warning("Dummy warning")
|
||||
tokenizer_slow.pad(encoding_slow)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Dummy warning",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
|
@ -48,6 +48,7 @@ from transformers import (
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
@ -81,6 +82,8 @@ if is_tokenizers_available():
|
||||
from test_module.custom_tokenization_fast import CustomTokenizerFast
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
||||
|
||||
SMALL_TRAINING_CORPUS = [
|
||||
@ -1834,6 +1837,47 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(attention_mask + [0] * padding_size, right_padded_attention_mask)
|
||||
self.assertEqual([0] * padding_size + attention_mask, left_padded_attention_mask)
|
||||
|
||||
def test_padding_warning_message_fast_tokenizer(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
sequence = "This is a text"
|
||||
|
||||
tokenizer_fast = self.get_rust_tokenizer()
|
||||
# check correct behaviour if no pad_token_id exists and add it eventually
|
||||
self._check_no_pad_token_padding(tokenizer_fast, sequence)
|
||||
|
||||
encoding_fast = tokenizer_fast(sequence)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
tokenizer_fast.pad(encoding_fast)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to"
|
||||
" encode the text followed by a call to the `pad` method to get a padded encoding.",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
if not self.test_slow_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer_slow = self.get_tokenizer()
|
||||
# check correct behaviour if no pad_token_id exists and add it eventually
|
||||
self._check_no_pad_token_padding(tokenizer_slow, sequence)
|
||||
|
||||
encoding_slow = tokenizer_slow(sequence)
|
||||
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
# We want to assert there are no warnings, but the 'assertLogs' method does not support that.
|
||||
# Therefore, we are adding a dummy warning, and then we will assert it is the only warning.
|
||||
logger.warning("Dummy warning")
|
||||
tokenizer_slow.pad(encoding_slow)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertIn(
|
||||
"Dummy warning",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
def test_separate_tokenizers(self):
|
||||
# This tests that tokenizers don't impact others. Unfortunately the case where it fails is when
|
||||
# we're loading an S3 configuration from a pre-trained identifier, and we have no way of testing those today.
|
||||
|
Loading…
Reference in New Issue
Block a user