Fix LayoutLMv2 test (#15939)

* Fix LayoutLMv2 test

* Update black
NielsRogge 2022-03-08 10:49:30 +01:00 committed by GitHub
parent 8b9ae45549
commit 9879a1d5f0

@@ -31,14 +31,7 @@ from transformers.models.layoutlmv2.tokenization_layoutlmv2 import (
     _is_punctuation,
     _is_whitespace,
 )
-from transformers.testing_utils import (
-    is_pt_tf_cross_test,
-    require_pandas,
-    require_scatter,
-    require_tokenizers,
-    require_torch,
-    slow,
-)
+from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow
 from ..test_tokenization_common import (
     SMALL_TRAINING_CORPUS,
@@ -1219,7 +1212,6 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @require_torch
     @slow
-    @require_scatter
     def test_torch_encode_plus_sent_to_model(self):
         import torch
@@ -1254,10 +1246,15 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                 words, boxes = self.get_words_and_boxes()
                 encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_tensors="pt")
                 batch_encoded_sequence = tokenizer.batch_encode_plus(
-                    [words, words], [boxes, boxes], return_tensors="pt"
+                    [words, words], boxes=[boxes, boxes], return_tensors="pt"
                 )
-                # This should not fail
+
+                # We add dummy image keys (as LayoutLMv2 actually also requires a feature extractor
+                # to prepare the image input)
+                encoded_sequence["image"] = torch.randn(1, 3, 224, 224)
+                batch_encoded_sequence["image"] = torch.randn(2, 3, 224, 224)
+
+                # This should not fail
                 with torch.no_grad():  # saves some time
                     model(**encoded_sequence)
                     model(**batch_encoded_sequence)
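
For reference, outside this test the "image" tensor would come from the LayoutLMv2 feature extractor rather than torch.randn. A minimal sketch of that path, assuming Pillow is installed; the file name "document.png" and the words/boxes values are made up for illustration:

    from PIL import Image
    from transformers import (
        LayoutLMv2FeatureExtractor,
        LayoutLMv2Processor,
        LayoutLMv2Tokenizer,
    )

    # apply_ocr=False because we supply words and boxes ourselves; with the
    # default (True), the feature extractor runs OCR via pytesseract instead.
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
    tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
    processor = LayoutLMv2Processor(feature_extractor, tokenizer)

    image = Image.open("document.png").convert("RGB")  # hypothetical document scan
    words = ["hello", "world"]
    boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # normalized 0-1000 coordinates

    # The processor resizes the image to 224x224 and returns it under the
    # "image" key alongside input_ids, bbox, attention_mask, etc. -- the
    # same keys the test fakes with random tensors.
    encoding = processor(image, words, boxes=boxes, return_tensors="pt")

Calling model(**encoding) then exercises the same forward pass the test checks, with a real 224x224 document image in place of the dummy tensor.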