diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 76c2c151fdf..0da576be742 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -2733,11 +2733,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         first_element = required_input[0]
         if isinstance(first_element, (list, tuple)):
             # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
-            index = 0
-            while len(required_input[index]) == 0:
-                index += 1
-            if index < len(required_input):
-                first_element = required_input[index][0]
+            for item in required_input:
+                if len(item) != 0:
+                    first_element = item[0]
+                    break
         # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
         if not isinstance(first_element, (int, list, tuple)):
             if is_tf_available() and _is_tensorflow(first_element):
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 685ccf43af1..e11f0d7150b 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -1393,7 +1393,7 @@ class TokenizerTesterMixin:
                 assert encoded_sequence == padded_sequence_left
 
     def test_padding_to_max_length(self):
-        """We keep this test for backward compatibility but it should be remove when `pad_to_max_length` will e deprecated"""
+        """We keep this test for backward compatibility but it should be removed when `pad_to_max_length` is deprecated."""
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
@@ -2997,8 +2997,8 @@ class TokenizerTesterMixin:
                 input_r = tokenizer_r.encode_plus("This is a input 1")
                 input_r = tokenizer_r.pad(input_r)
 
-                input_p = tokenizer_r.encode_plus("This is a input 1")
-                input_p = tokenizer_r.pad(input_p)
+                input_p = tokenizer_p.encode_plus("This is a input 1")
+                input_p = tokenizer_p.pad(input_p)
 
                 self.assert_padded_input_match(
                     input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
@@ -3008,8 +3008,8 @@ class TokenizerTesterMixin:
                 input_r = tokenizer_r.encode_plus("This is a input 1")
                 input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
 
-                input_p = tokenizer_r.encode_plus("This is a input 1")
-                input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+                input_p = tokenizer_p.encode_plus("This is a input 1")
+                input_p = tokenizer_p.pad(input_p, max_length=max_length, padding="max_length")
 
                 self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
 
@@ -3019,10 +3019,10 @@ class TokenizerTesterMixin:
                 )
                 input_r = tokenizer_r.pad(input_r)
 
-                input_p = tokenizer_r.batch_encode_plus(
+                input_p = tokenizer_p.batch_encode_plus(
                     ["This is a input 1", "This is a much longer input whilch should be padded"]
                 )
-                input_p = tokenizer_r.pad(input_p)
+                input_p = tokenizer_p.pad(input_p)
 
                 self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
 
@@ -3032,11 +3032,15 @@ class TokenizerTesterMixin:
                 )
                 input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
 
-                input_p = tokenizer_r.batch_encode_plus(
+                input_p = tokenizer_p.batch_encode_plus(
                     ["This is a input 1", "This is a much longer input whilch should be padded"]
                 )
-                input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+                input_p = tokenizer_p.pad(input_p, max_length=max_length, padding="max_length")
 
+                self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+                # Test padding nested empty lists (in some use cases, there are no token ids in the `input_ids` lists).
+                input_r = tokenizer_r.pad({"input_ids": [[], []]}, max_length=max_length, padding="max_length")
+                input_p = tokenizer_p.pad({"input_ids": [[], []]}, max_length=max_length, padding="max_length")
                 self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
 
     def test_padding_different_model_input_name(self):
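
For reference, a minimal sketch of the edge case the new test exercises (the `bert-base-uncased` checkpoint used below is an assumption; any tokenizer works): with the previous `while`-based lookup, `PreTrainedTokenizerBase.pad` walked past the end of `required_input` when every entry in `input_ids` was empty and raised `IndexError: list index out of range`; with the `for` loop, `first_element` is simply left as an empty list and each sequence is padded out to `max_length`.

```python
# Minimal sketch of the behaviour covered by the new test case (not part of the patch).
# Assumes the `bert-base-uncased` checkpoint is available; any tokenizer would do.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Every sequence in the batch is empty: the old while-loop indexed past the end of
# `required_input` and raised IndexError; the new for-loop leaves `first_element` as-is.
batch = tokenizer.pad({"input_ids": [[], []]}, max_length=8, padding="max_length")

# Each empty sequence is padded up to `max_length` with the pad token id.
print(batch["input_ids"])       # e.g. [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]]
print(batch["attention_mask"])  # all zeros, since no real tokens are present
```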