Integration test for mobilebert (#9978)

This commit is contained in:
sandip 2021-02-03 22:06:45 +05:30 committed by GitHub
parent 1486205d23
commit ce08043f7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -319,3 +319,26 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
for model_name in ["google/mobilebert-uncased"]:
model = TFMobileBertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_tf
class TFMobileBertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = TFMobileBertForPreTraining.from_pretrained("google/mobilebert-uncased")
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
expected_shape = [1, 6, 30522]
self.assertEqual(output.shape, expected_shape)
expected_slice = tf.constant(
[
[
[-4.5919547, -9.248295, -9.645256],
[-6.7306175, -6.440284, -6.6052837],
[-7.2743506, -6.7847915, -6.024673],
]
]
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)