mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
fix
This commit is contained in:
parent
50d48aaa8a
commit
ca688e7449
@ -16,7 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import MobileViTConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -304,7 +304,13 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-1.9401, -1.2384, -0.4702]).to(torch_device)
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): [-1.9364, -1.2327, -0.4653],
|
||||
("cuda", 8): [-1.9401, -1.2384, -0.4702],
|
||||
}
|
||||
)
|
||||
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
|
||||
|
||||
torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
|
||||
|
||||
@ -327,14 +333,21 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_shape = torch.Size((1, 21, 32, 32))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[[6.9661, 6.9753, 7.2386], [7.2864, 7.2785, 7.4429], [7.6577, 7.8770, 7.9387]],
|
||||
[[-10.7046, -10.3411, -10.3641], [-10.4402, -10.0004, -9.7269], [-11.0579, -11.0358, -10.7459]],
|
||||
[[-3.3022, -2.8465, -2.6661], [-3.2654, -2.5542, -2.5055], [-3.2477, -2.6544, -2.6562]],
|
||||
],
|
||||
device=torch_device,
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): [
|
||||
[[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]],
|
||||
[[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]],
|
||||
[[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]],
|
||||
],
|
||||
("cuda", 8): [
|
||||
[[6.9661, 6.9753, 7.2386], [7.2864, 7.2785, 7.4429], [7.6577, 7.8770, 7.9387]],
|
||||
[[-10.7046, -10.3411, -10.3641], [-10.4402, -10.0004, -9.7269], [-11.0579, -11.0358, -10.7459]],
|
||||
[[-3.3022, -2.8465, -2.6661], [-3.2654, -2.5542, -2.5055], [-3.2477, -2.6544, -2.6562]],
|
||||
],
|
||||
}
|
||||
)
|
||||
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
|
||||
|
||||
torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user