diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py index bca3158e083..064ed6219c8 100644 --- a/tests/models/mask2former/test_modeling_mask2former.py +++ b/tests/models/mask2former/test_modeling_mask2former.py @@ -21,6 +21,7 @@ from tests.test_modeling_common import floats_tensor from transformers import AutoModelForImageClassification, Mask2FormerConfig, is_torch_available, is_vision_available from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 from transformers.testing_utils import ( + Expectations, require_timm, require_torch, require_torch_accelerator, @@ -451,13 +452,21 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase): rtol=TOLERANCE, ) - expected_slice_hidden_state = torch.tensor( - [ - [0.8974, 1.1848, 1.1777], - [1.1933, 1.5041, 1.5128], - [1.1154, 1.4487, 1.4950], - ] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [0.8973, 1.1847, 1.1776], + [1.1934, 1.5040, 1.5128], + [1.1153, 1.4486, 1.4951], + ], + ("cuda", 8): [ + [0.8974, 1.1848, 1.1777], + [1.1933, 1.5041, 1.5128], + [1.1154, 1.4487, 1.4950], + ], + } + ) + expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close( outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, @@ -467,11 +476,24 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase): expected_slice_hidden_state = torch.tensor( [ - [2.1153, 1.7004, -0.8604], - [1.5807, 1.8007, -0.9354], - [1.6040, 1.7498, -0.6001], + ] ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [2.1152, 1.7000, -0.8603], + [1.5808, 1.8004, -0.9353], + [1.6043, 1.7495, -0.5999], + ], + ("cuda", 8): [ + [2.1153, 1.7004, -0.8604], + [1.5807, 1.8007, -0.9354], + [1.6040, 1.7498, -0.6001], + ], + } + ) + expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close( outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state,