mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
implementing tflxmertmodel integration test (#12497)
* implementing tflxmertmodel integration test * move import * revert and fix
This commit is contained in:
parent
09af5bdea3
commit
3fd85777ea
@ -17,6 +17,8 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import LxmertConfig, is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
@ -555,8 +557,6 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
tf_hidden_states = tfo[0].numpy()
|
||||
pt_hidden_states = pto[0].numpy()
|
||||
|
||||
import numpy as np
|
||||
|
||||
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
||||
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
||||
|
||||
@ -768,3 +768,30 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
for attention, attention_shape in zip(attentions, attention_shapes):
|
||||
self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFLxmertModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = TFLxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
|
||||
input_ids = tf.constant([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])
|
||||
|
||||
num_visual_features = 10
|
||||
_, visual_feats = np.random.seed(0), np.random.rand(1, num_visual_features, model.config.visual_feat_dim)
|
||||
_, visual_pos = np.random.seed(0), np.random.rand(1, num_visual_features, 4)
|
||||
visual_feats = tf.convert_to_tensor(visual_feats, dtype=tf.float32)
|
||||
visual_pos = tf.convert_to_tensor(visual_pos, dtype=tf.float32)
|
||||
output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0]
|
||||
expected_shape = [1, 11, 768]
|
||||
self.assertEqual(expected_shape, output.shape)
|
||||
expected_slice = tf.constant(
|
||||
[
|
||||
[
|
||||
[0.24170142, -0.98075, 0.14797261],
|
||||
[1.2540525, -0.83198136, 0.5112344],
|
||||
[1.4070463, -1.1051831, 0.6990401],
|
||||
]
|
||||
]
|
||||
)
|
||||
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
|
||||
|
Loading…
Reference in New Issue
Block a user