mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Integration test for mobilebert (#9978)
This commit is contained in:
parent
1486205d23
commit
ce08043f7a
@ -319,3 +319,26 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in ["google/mobilebert-uncased"]:
|
||||
model = TFMobileBertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFMobileBertModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = TFMobileBertForPreTraining.from_pretrained("google/mobilebert-uncased")
|
||||
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = [1, 6, 30522]
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = tf.constant(
|
||||
[
|
||||
[
|
||||
[-4.5919547, -9.248295, -9.645256],
|
||||
[-6.7306175, -6.440284, -6.6052837],
|
||||
[-7.2743506, -6.7847915, -6.024673],
|
||||
]
|
||||
]
|
||||
)
|
||||
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
|
||||
|
Loading…
Reference in New Issue
Block a user