This commit is contained in:
ydshieh 2025-07-02 07:53:19 +02:00
parent 8ebb6de590
commit bec9d4fbab

View File

@ -301,9 +301,9 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1001))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([0.2445, -1.1993, 0.1905]).to(torch_device)
expected_slice = torch.tensor([ 0.2445, -1.1970, 0.1868]).to(torch_device)
torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_inference_semantic_segmentation(self):
@ -326,11 +326,23 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor(
[
[[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]],
[[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]],
[[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]],
[
[17.5809, 17.7571, 18.3341],
[18.3240, 18.4216, 18.8974],
[18.6174, 18.8662, 19.2177],
],
[
[-2.1562, -2.0942, -2.3703],
[-2.4199, -2.2999, -2.6818],
[-2.7800, -2.5944, -2.7678],
],
[
[4.2092, 4.8356, 4.7694],
[4.4181, 5.0401, 4.9409],
[4.5089, 4.9700, 4.8802],
],
],
device=torch_device,
)
torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)