mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix
This commit is contained in:
parent
6f59e7492c
commit
40cc1bfd3d
@ -635,7 +635,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.9982, 0.9960, 0.9955, 0.9988, 0.9987]).to(torch_device)
|
||||
expected_labels = [75, 75, 63, 17, 17]
|
||||
expected_slice_boxes = torch.tensor([40.1615, 70.8090, 175.5476, 117.9810]).to(torch_device)
|
||||
expected_slice_boxes = torch.tensor([40.1615, 70.8090, 175.5476, 117.9810]).to(torch_device)
|
||||
|
||||
self.assertEqual(len(results["scores"]), 5)
|
||||
torch.testing.assert_close(results["scores"], expected_scores, rtol=2e-4, atol=2e-4)
|
||||
@ -669,22 +669,14 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
|
||||
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
|
||||
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.5344, 0.1790, 0.9284],
|
||||
[0.4421, 0.0571, 0.0875],
|
||||
[0.6632, 0.6886, 0.1015]
|
||||
]
|
||||
[[0.5344, 0.1790, 0.9284], [0.4421, 0.0571, 0.0875], [0.6632, 0.6886, 0.1015]]
|
||||
).to(torch_device)
|
||||
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, rtol=2e-4, atol=2e-4)
|
||||
|
||||
expected_shape_masks = torch.Size((1, model.config.num_queries, 200, 267))
|
||||
self.assertEqual(outputs.pred_masks.shape, expected_shape_masks)
|
||||
expected_slice_masks = torch.tensor(
|
||||
[
|
||||
[-7.8408, -11.0104, -12.1279],
|
||||
[-12.0299, -16.6498, -17.9806],
|
||||
[-14.8995, -19.9940, -20.5646]
|
||||
]
|
||||
[[-7.8408, -11.0104, -12.1279], [-12.0299, -16.6498, -17.9806], [-14.8995, -19.9940, -20.5646]]
|
||||
).to(torch_device)
|
||||
torch.testing.assert_close(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, rtol=2e-3, atol=2e-3)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user