implementing tflxmertmodel integration test (#12497)

* implementing tflxmertmodel integration test

* move import

* revert and fix
This commit is contained in:
sadakmed 2021-07-06 17:44:47 +02:00 committed by GitHub
parent 09af5bdea3
commit 3fd85777ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,6 +17,8 @@ import os
import tempfile
import unittest
import numpy as np
from transformers import LxmertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
@ -555,8 +557,6 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
import numpy as np
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
@ -768,3 +768,30 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
for attention, attention_shape in zip(attentions, attention_shapes):
self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
@require_tf
class TFLxmertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = TFLxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
input_ids = tf.constant([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])
num_visual_features = 10
_, visual_feats = np.random.seed(0), np.random.rand(1, num_visual_features, model.config.visual_feat_dim)
_, visual_pos = np.random.seed(0), np.random.rand(1, num_visual_features, 4)
visual_feats = tf.convert_to_tensor(visual_feats, dtype=tf.float32)
visual_pos = tf.convert_to_tensor(visual_pos, dtype=tf.float32)
output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0]
expected_shape = [1, 11, 768]
self.assertEqual(expected_shape, output.shape)
expected_slice = tf.constant(
[
[
[0.24170142, -0.98075, 0.14797261],
[1.2540525, -0.83198136, 0.5112344],
[1.4070463, -1.1051831, 0.6990401],
]
]
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)