fix retribert's test_torch_encode_plus_sent_to_model (#17231)

SaulLu, 2022-05-17 14:33:13 +02:00, committed by GitHub
parent ec7f8af106
commit 6d211429ec

@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
     _is_punctuation,
     _is_whitespace,
 )
-from transformers.testing_utils import require_tokenizers, slow
+from transformers.testing_utils import require_tokenizers, require_torch, slow
 
-from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
 
 # Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
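
The newly imported helper, `merge_model_tokenizer_mappings`, lives in `tests/test_tokenization_common.py`. A minimal sketch of the merge it is assumed to perform — re-keying the two config-indexed auto-mappings by tokenizer class — could look like this (local names are illustrative, not the library's):

    from collections import OrderedDict

    def merge_model_tokenizer_mappings(model_mapping, tokenizer_mapping):
        # Both inputs are keyed by config class:
        #   model_mapping:     config class -> model class          (e.g. MODEL_MAPPING)
        #   tokenizer_mapping: config class -> (slow, fast) classes (e.g. TOKENIZER_MAPPING)
        # The result is keyed by tokenizer class instead, so a tokenization test
        # can look up the (config, model) pair that goes with the tokenizer under test.
        merged = OrderedDict()
        for config_class, model_class in model_mapping.items():
            tokenizers = tokenizer_mapping.get(config_class)
            if tokenizers is None:
                continue
            slow_tokenizer, fast_tokenizer = tokenizers
            if slow_tokenizer is not None:
                merged[slow_tokenizer] = (config_class, model_class)
            if fast_tokenizer is not None:
                merged[fast_tokenizer] = (config_class, model_class)
        return merged
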
@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         ]
         self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
         self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
+
+    # RetriBertModel doesn't define `get_input_embeddings` and its forward method doesn't take only the output of the tokenizer as input
+    @require_torch
+    @slow
+    def test_torch_encode_plus_sent_to_model(self):
+        import torch
+
+        from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+                    return
+
+                config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+                config = config_class()
+
+                if config.is_encoder_decoder or config.pad_token_id is None:
+                    return
+
+                model = model_class(config)
+
+                # The following assertion differs from the common test: RetriBert keeps its input embeddings on the `bert_query` submodule
+                self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
+
+                # Build a sequence from the first ten vocab tokens
+                first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+                sequence = " ".join(first_ten_tokens)
+                encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+                # Ensure that the BatchEncoding.to() method works.
+                encoded_sequence.to(model.device)
+
+                batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+                # This should not fail
+
+                with torch.no_grad():  # saves some time
+                    # The following calls differ from the common test's `model(**encoded_sequence)`
+                    model.embed_questions(**encoded_sequence)
+                    model.embed_questions(**batch_encoded_sequence)
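
Why `embed_questions` instead of the common test's `model(**encoded_sequence)`: `RetriBertModel.forward` scores question/document pairs, so it expects two encoded inputs (query and doc ids plus their attention masks) and cannot be fed a single tokenizer output. `embed_questions` takes exactly one side, and since `RetriBertTokenizer` sets `model_input_names = ["input_ids", "attention_mask"]`, the encoding dict matches its signature. A minimal usage sketch, assuming network access to the public `yjernite/retribert-base-uncased` checkpoint:

    import torch
    from transformers import RetriBertModel, RetriBertTokenizer

    tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")
    model = RetriBertModel.from_pretrained("yjernite/retribert-base-uncased")

    # The encoding contains only input_ids and attention_mask, so it can be
    # unpacked straight into embed_questions.
    inputs = tokenizer(["how do solar panels work?"], return_tensors="pt", padding=True)
    with torch.no_grad():
        question_embeddings = model.embed_questions(**inputs)
    print(question_embeddings.shape)  # (batch_size, projection_dim)

To run the new test itself, something like `RUN_SLOW=1 python -m pytest tests/models/retribert/test_tokenization_retribert.py -k test_torch_encode_plus_sent_to_model` should work; the path is inferred from the relative import in the diff, and `@slow` tests are skipped unless `RUN_SLOW` is set.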