mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
create LxmertModelIntegrationTest Pytorch (#9989)
* create LxmertModelIntegrationTest * implementation using numpy seeding to fix inputs params. * fix code quality * isort check
This commit is contained in:
parent
23ab0b6980
commit
0e1718afb6
@ -17,6 +17,8 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
@ -727,3 +729,24 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(attentions_vision.grad)
|
||||
self.assertIsNotNone(hidden_states_vision.grad)
|
||||
self.assertIsNotNone(attentions_vision.grad)
|
||||
|
||||
|
||||
@require_torch
|
||||
class LxmertModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head_absolute_embedding(self):
|
||||
model = LxmertModel.from_pretrained(LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
|
||||
input_ids = torch.tensor([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])
|
||||
num_visual_features = 10
|
||||
_, visual_feats = np.random.seed(0), np.random.rand(1, num_visual_features, LxmertModel.config.visual_feat_dim)
|
||||
_, visual_pos = np.random.seed(0), np.random.rand(1, num_visual_features, 4)
|
||||
visual_feats = torch.as_tensor(visual_feats, dtype=torch.float32)
|
||||
visual_pos = torch.as_tensor(visual_pos, dtype=torch.float32)
|
||||
output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0]
|
||||
expected_shape = torch.Size([1, 11, 768])
|
||||
self.assertEqual(expected_shape, output.shape)
|
||||
expected_slice = torch.tensor(
|
||||
[[[0.2417, -0.9807, 0.1480], [1.2541, -0.8320, 0.5112], [1.4070, -1.1052, 0.6990]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
Loading…
Reference in New Issue
Block a user