This commit is contained in:
ydshieh 2025-07-02 12:25:43 +02:00
parent 50d48aaa8a
commit ca688e7449

View File

@ -16,7 +16,7 @@
import unittest
from transformers import MobileViTConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@ -304,7 +304,13 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-1.9401, -1.2384, -0.4702]).to(torch_device)
expectations = Expectations(
{
(None, None): [-1.9364, -1.2327, -0.4653],
("cuda", 8): [-1.9401, -1.2384, -0.4702],
}
)
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@ -327,14 +333,21 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 21, 32, 32))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[6.9661, 6.9753, 7.2386], [7.2864, 7.2785, 7.4429], [7.6577, 7.8770, 7.9387]],
[[-10.7046, -10.3411, -10.3641], [-10.4402, -10.0004, -9.7269], [-11.0579, -11.0358, -10.7459]],
[[-3.3022, -2.8465, -2.6661], [-3.2654, -2.5542, -2.5055], [-3.2477, -2.6544, -2.6562]],
],
device=torch_device,
expectations = Expectations(
{
(None, None): [
[[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]],
[[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]],
[[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]],
],
("cuda", 8): [
[[6.9661, 6.9753, 7.2386], [7.2864, 7.2785, 7.4429], [7.6577, 7.8770, 7.9387]],
[[-10.7046, -10.3411, -10.3641], [-10.4402, -10.0004, -9.7269], [-11.0579, -11.0358, -10.7459]],
[[-3.3022, -2.8465, -2.6661], [-3.2654, -2.5542, -2.5055], [-3.2477, -2.6544, -2.6562]],
],
}
)
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)