From 6d211429ec0ee281e2b9552246bd200e5e299626 Mon Sep 17 00:00:00 2001
From: SaulLu <55560583+SaulLu@users.noreply.github.com>
Date: Tue, 17 May 2022 14:33:13 +0200
Subject: [PATCH] fix retribert's `test_torch_encode_plus_sent_to_model`
 (#17231)

---
 .../retribert/test_tokenization_retribert.py | 48 ++++++++++++++++++-
 1 file changed, 46 insertions(+), 2 deletions(-)

diff --git a/tests/models/retribert/test_tokenization_retribert.py b/tests/models/retribert/test_tokenization_retribert.py
index e6511bdbb7c..e2bf4e61b1a 100644
--- a/tests/models/retribert/test_tokenization_retribert.py
+++ b/tests/models/retribert/test_tokenization_retribert.py
@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
     _is_punctuation,
     _is_whitespace,
 )
-from transformers.testing_utils import require_tokenizers, slow
+from transformers.testing_utils import require_tokenizers, require_torch, slow
 
-from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
 
 
 # Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         ]
         self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
         self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
+
+    # RetriBertModel doesn't define `get_input_embeddings`, and its forward method takes more than just the tokenizer's output as input
+    @require_torch
+    @slow
+    def test_torch_encode_plus_sent_to_model(self):
+        import torch
+
+        from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+                if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+                    return
+
+                config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+                config = config_class()
+
+                if config.is_encoder_decoder or config.pad_token_id is None:
+                    return
+
+                model = model_class(config)
+
+                # The following check differs from the common test: the embedding table lives on `model.bert_query`
+                self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
+
+                # Build sequence
+                first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+                sequence = " ".join(first_ten_tokens)
+                encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+                # Ensure that the BatchEncoding.to() method works.
+                encoded_sequence.to(model.device)
+
+                batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+
+                # This should not fail
+                with torch.no_grad():  # saves some time
+                    # The following lines differ from the common test, which calls the model's forward directly
+                    model.embed_questions(**encoded_sequence)
+                    model.embed_questions(**batch_encoded_sequence)
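
For context, here is a minimal standalone sketch of the call pattern the added test exercises, outside the test harness. The checkpoint name `yjernite/retribert-base-uncased` (used only to fetch a vocabulary, assuming network access) and the printed shape (default `RetriBertConfig`, `projection_dim=128`) are assumptions for illustration, not part of the patch.

    # Minimal sketch: why the test calls model.embed_questions(...) rather
    # than model(...). RetriBertModel's forward() computes the retrieval
    # loss and expects both a query batch and a document batch, so the
    # common test's plain `model(**encoded_sequence)` call does not apply;
    # embed_questions() only needs the query side.
    import torch

    from transformers import RetriBertConfig, RetriBertModel, RetriBertTokenizer

    config = RetriBertConfig()  # random weights are enough for a smoke test
    model = RetriBertModel(config)

    # Assumes network access to pull the published checkpoint's vocabulary
    tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")

    inputs = tokenizer("hello world", return_tensors="pt")
    with torch.no_grad():
        # Pass the tensors explicitly instead of unpacking the whole BatchEncoding
        question_embedding = model.embed_questions(
            inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
    print(question_embedding.shape)  # torch.Size([1, 128]) with the default projection_dim

Unpacking the `BatchEncoding` directly, as the test does, also works because `RetriBertTokenizer` restricts `model_input_names` to `["input_ids", "attention_mask"]`, so `encode_plus` does not emit `token_type_ids` that `embed_questions` would reject.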