mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Implementing TFLxmertModel integration test (#12497)
* implementing TFLxmertModel integration test
* move import
* revert and fix
This commit is contained in:
parent
09af5bdea3
commit
3fd85777ea
@ -17,6 +17,8 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import LxmertConfig, is_tf_available
|
from transformers import LxmertConfig, is_tf_available
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
@ -555,8 +557,6 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
tf_hidden_states = tfo[0].numpy()
|
tf_hidden_states = tfo[0].numpy()
|
||||||
pt_hidden_states = pto[0].numpy()
|
pt_hidden_states = pto[0].numpy()
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
||||||
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
||||||
|
|
||||||
@ -768,3 +768,30 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
for attention, attention_shape in zip(attentions, attention_shapes):
|
for attention, attention_shape in zip(attentions, attention_shapes):
|
||||||
self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
|
self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
class TFLxmertModelIntegrationTest(unittest.TestCase):
    """Integration test that runs a pretrained LXMERT checkpoint through the TF model."""

    @slow
    def test_inference_masked_lm(self):
        """Compare the first model output with recorded reference values.

        Loads ``unc-nlp/lxmert-base-uncased`` into ``TFLxmertModel``, feeds it a
        fixed token sequence plus seeded pseudo-random visual inputs, and checks
        both the shape of the first output and its leading 3x3 slice.

        NOTE(review): the method name mentions masked LM, but the model loaded
        here is the base ``TFLxmertModel``; the name is kept unchanged so test
        selection by name keeps working.
        """
        model = TFLxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
        input_ids = tf.constant([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])

        num_visual_features = 10
        # Each input restarts from seed 0 so the values match the ones the
        # reference slice below was recorded with. A private RandomState is used
        # instead of np.random.seed so the process-wide NumPy RNG is not mutated.
        visual_feats = np.random.RandomState(0).rand(1, num_visual_features, model.config.visual_feat_dim)
        visual_pos = np.random.RandomState(0).rand(1, num_visual_features, 4)
        visual_feats = tf.convert_to_tensor(visual_feats, dtype=tf.float32)
        visual_pos = tf.convert_to_tensor(visual_pos, dtype=tf.float32)

        output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0]

        expected_shape = [1, 11, 768]
        self.assertEqual(expected_shape, output.shape)

        expected_slice = tf.constant(
            [
                [
                    [0.24170142, -0.98075, 0.14797261],
                    [1.2540525, -0.83198136, 0.5112344],
                    [1.4070463, -1.1051831, 0.6990401],
                ]
            ]
        )
        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
|
||||||
|
Loading…
Reference in New Issue
Block a user