This commit is contained in:
ydshieh 2025-07-02 12:19:01 +02:00
parent a92786d77a
commit 1c407778e2

View File

@ -301,7 +301,13 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1001))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([0.2445, -1.1970, 0.1868]).to(torch_device)
expectations = Expectations(
{
(None, None): [0.2445, -1.1993, 0.1905],
("cuda", 8): [0.2445, -1.1970, 0.1868],
}
)
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@ -324,25 +330,32 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 21, 65, 65))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[
[17.5809, 17.7571, 18.3341],
[18.3240, 18.4216, 18.8974],
[18.6174, 18.8662, 19.2177],
expectations = Expectations(
{
(None, None): [
[[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]],
[[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]],
[[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]],
],
[
[-2.1562, -2.0942, -2.3703],
[-2.4199, -2.2999, -2.6818],
[-2.7800, -2.5944, -2.7678],
("cuda", 8): [
[
[17.5809, 17.7571, 18.3341],
[18.3240, 18.4216, 18.8974],
[18.6174, 18.8662, 19.2177],
],
[
[-2.1562, -2.0942, -2.3703],
[-2.4199, -2.2999, -2.6818],
[-2.7800, -2.5944, -2.7678],
],
[
[4.2092, 4.8356, 4.7694],
[4.4181, 5.0401, 4.9409],
[4.5089, 4.9700, 4.8802],
],
],
[
[4.2092, 4.8356, 4.7694],
[4.4181, 5.0401, 4.9409],
[4.5089, 4.9700, 4.8802],
],
],
device=torch_device,
}
)
expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)