Added Integration testing for DistilBert model from issue #9948' (#9995)

This commit is contained in:
Daniel Hug 2021-02-04 04:24:59 -05:00 committed by GitHub
parent 00031785a8
commit 804cd185d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -24,6 +24,8 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
from transformers import (
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
DistilBertConfig,
@ -246,3 +248,19 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = DistilBertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class DistilBertModelIntergrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = DistilBertModel.from_pretrained("distilbert-base-uncased")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 768))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[[0.4026, -0.2919, 0.3902], [0.3828, -0.2129, 0.3563], [0.3919, -0.2287, 0.3438]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))