This commit is contained in:
ydshieh 2025-07-01 20:17:58 +02:00
parent 6f59e7492c
commit 40cc1bfd3d

View File

@ -635,7 +635,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
)[0]
expected_scores = torch.tensor([0.9982, 0.9960, 0.9955, 0.9988, 0.9987]).to(torch_device)
expected_labels = [75, 75, 63, 17, 17]
expected_slice_boxes = torch.tensor([40.1615, 70.8090, 175.5476, 117.9810]).to(torch_device)
expected_slice_boxes = torch.tensor([40.1615, 70.8090, 175.5476, 117.9810]).to(torch_device)
self.assertEqual(len(results["scores"]), 5)
torch.testing.assert_close(results["scores"], expected_scores, rtol=2e-4, atol=2e-4)
@ -669,22 +669,14 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
expected_slice_boxes = torch.tensor(
[
[0.5344, 0.1790, 0.9284],
[0.4421, 0.0571, 0.0875],
[0.6632, 0.6886, 0.1015]
]
[[0.5344, 0.1790, 0.9284], [0.4421, 0.0571, 0.0875], [0.6632, 0.6886, 0.1015]]
).to(torch_device)
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, rtol=2e-4, atol=2e-4)
expected_shape_masks = torch.Size((1, model.config.num_queries, 200, 267))
self.assertEqual(outputs.pred_masks.shape, expected_shape_masks)
expected_slice_masks = torch.tensor(
[
[-7.8408, -11.0104, -12.1279],
[-12.0299, -16.6498, -17.9806],
[-14.8995, -19.9940, -20.5646]
]
[[-7.8408, -11.0104, -12.1279], [-12.0299, -16.6498, -17.9806], [-14.8995, -19.9940, -20.5646]]
).to(torch_device)
torch.testing.assert_close(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, rtol=2e-3, atol=2e-3)