This commit is contained in:
ydshieh 2025-07-02 12:30:30 +02:00
parent ca688e7449
commit 571aa68422

View File

@ -16,7 +16,14 @@
import unittest
from transformers import MobileViTV2Config
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.testing_utils import (
Expectations,
require_torch,
require_torch_multi_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@ -317,7 +324,13 @@ class MobileViTV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-1.6341, -0.0665, -0.5158]).to(torch_device)
expectations = Expectations(
{
(None, None): [-1.6336e00, -7.3204e-02, -5.1883e-01],
("cuda", 8): [-1.6341, -0.0665, -0.5158],
}
)
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@ -340,14 +353,21 @@ class MobileViTV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 21, 32, 32))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor(
[
expectations = Expectations(
{
(None, None): [
[[7.0863, 7.1525, 6.8201], [6.6931, 6.8770, 6.8933], [6.2978, 7.0366, 6.9636]],
[[-3.7134, -3.6712, -3.6675], [-3.5825, -3.3549, -3.4777], [-3.3435, -3.3979, -3.2857]],
[[-2.9329, -2.8003, -2.7369], [-3.0564, -2.4780, -2.0207], [-2.6889, -1.9298, -1.7640]],
],
("cuda", 8): [
[[7.0866, 7.1509, 6.8188], [6.6935, 6.8757, 6.8927], [6.2988, 7.0365, 6.9631]],
[[-3.7113, -3.6686, -3.6643], [-3.5801, -3.3516, -3.4739], [-3.3432, -3.3966, -3.2832]],
[[-2.9359, -2.8037, -2.7387], [-3.0595, -2.4798, -2.0222], [-2.6901, -1.9306, -1.7659]],
],
device=torch_device,
}
)
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)