Mirror of https://github.com/huggingface/transformers.git
[Tests] Fix ViTMAE integration test (#15949)
* Fix test across both cpu and gpu
* Fix typo
parent 9879a1d5f0
commit b19f3e69a0
@@ -401,6 +401,9 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference_for_pretraining(self):
         # make random mask reproducible
+        # note that the same seed on CPU and on GPU doesn't mean they spew the same random number sequences,
+        # as they both have fairly different PRNGs (for efficiency reasons).
+        # source: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735
         torch.manual_seed(2)

         model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
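The comment added above names the root cause; as a minimal standalone sketch (not part of the commit), the same seed drives separate CPU and CUDA generators, so the random mask, and therefore the logits, differ per device:

import torch

# torch.manual_seed seeds both the CPU generator and every CUDA generator,
# but the two are different PRNG implementations, so their sequences differ.
torch.manual_seed(2)
cpu_noise = torch.rand(1, 196)                     # drawn from the CPU generator

if torch.cuda.is_available():
    gpu_noise = torch.rand(1, 196, device="cuda")  # drawn from the CUDA generator
    # Generally False: one hard-coded expected_slice cannot match both devices,
    # hence the expected_slice_cpu / expected_slice_gpu split below.
    print(torch.allclose(cpu_noise, gpu_noise.cpu()))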
@@ -417,8 +420,14 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
         expected_shape = torch.Size((1, 196, 768))
         self.assertEqual(outputs.logits.shape, expected_shape)

-        expected_slice = torch.tensor(
+        expected_slice_cpu = torch.tensor(
             [[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]]
-        ).to(torch_device)
+        )
+        expected_slice_gpu = torch.tensor(
+            [[0.8948, -1.0680, 0.0030], [0.9758, -1.1181, -0.0290], [1.0602, -1.1522, -0.0528]]
+        )

-        self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4))
+        # set expected slice depending on device
+        expected_slice = expected_slice_cpu if torch_device == "cpu" else expected_slice_gpu
+
+        self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
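For reference, a rough sketch of how device-specific expected values for a test like this could be regenerated. It assumes the facebook/vit-mae-base checkpoint used in the diff; the random stand-in input means the printed numbers will not reproduce the exact slices above, which come from the test's preprocessed image.

import torch
from transformers import ViTMAEForPreTraining

# Stand-in input; the real test feeds a preprocessed image through the model.
pixel_values = torch.randn(1, 3, 224, 224)

for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []):
    model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(device)
    torch.manual_seed(2)  # same seed as the test; the random mask still differs per device
    with torch.no_grad():
        logits = model(pixel_values.to(device)).logits
    print(device, logits[0, :3, :3])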