[Tests] Fix ViTMAE integration test (#15949)

* Fix test across both cpu and gpu

* Fix typo
This commit is contained in:
NielsRogge 2022-03-08 10:49:44 +01:00 committed by GitHub
parent 9879a1d5f0
commit b19f3e69a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -401,6 +401,9 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_for_pretraining(self):
# make random mask reproducible
# note that the same seed on CPU and on GPU doesnt mean they spew the same random number sequences,
# as they both have fairly different PRNGs (for efficiency reasons).
# source: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735
torch.manual_seed(2)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
@ -417,8 +420,14 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 196, 768))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor(
expected_slice_cpu = torch.tensor(
[[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]]
).to(torch_device)
)
expected_slice_gpu = torch.tensor(
[[0.8948, -1.0680, 0.0030], [0.9758, -1.1181, -0.0290], [1.0602, -1.1522, -0.0528]]
)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4))
# set expected slice depending on device
expected_slice = expected_slice_cpu if torch_device == "cpu" else expected_slice_gpu
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))