mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix retribert's test_torch_encode_plus_sent_to_model
(#17231)
This commit is contained in:
parent
ec7f8af106
commit
6d211429ec
@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
|
||||
_is_punctuation,
|
||||
_is_whitespace,
|
||||
)
|
||||
from transformers.testing_utils import require_tokenizers, slow
|
||||
from transformers.testing_utils import require_tokenizers, require_torch, slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
|
||||
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
|
||||
|
||||
|
||||
# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
|
||||
@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
|
||||
self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
|
||||
|
||||
# RetriBertModel doesn't define `get_input_embeddings` and it's forward method doesn't take only the output of the tokenizer as input
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
import torch
|
||||
|
||||
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
|
||||
|
||||
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
|
||||
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
|
||||
if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
|
||||
return
|
||||
|
||||
config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
|
||||
config = config_class()
|
||||
|
||||
if config.is_encoder_decoder or config.pad_token_id is None:
|
||||
return
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
# The following test is different from the common's one
|
||||
self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
|
||||
|
||||
# Build sequence
|
||||
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
|
||||
sequence = " ".join(first_ten_tokens)
|
||||
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
|
||||
|
||||
# Ensure that the BatchEncoding.to() method works.
|
||||
encoded_sequence.to(model.device)
|
||||
|
||||
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
|
||||
# This should not fail
|
||||
|
||||
with torch.no_grad(): # saves some time
|
||||
# The following lines are different from the common's ones
|
||||
model.embed_questions(**encoded_sequence)
|
||||
model.embed_questions(**batch_encoded_sequence)
|
||||
|
Loading…
Reference in New Issue
Block a user