diff --git a/tests/models/colqwen2/test_modeling_colqwen2.py b/tests/models/colqwen2/test_modeling_colqwen2.py index 80aae3e9cfd..3ed42afd99a 100644 --- a/tests/models/colqwen2/test_modeling_colqwen2.py +++ b/tests/models/colqwen2/test_modeling_colqwen2.py @@ -26,7 +26,15 @@ from transformers import is_torch_available from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2ForRetrievalOutput from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor -from transformers.testing_utils import cleanup, require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + Expectations, + cleanup, + require_bitsandbytes, + require_torch, + require_vision, + slow, + torch_device, +) if is_torch_available(): @@ -283,6 +291,7 @@ class ColQwen2ModelIntegrationTest(unittest.TestCase): def tearDown(self): cleanup(torch_device, gc_collect=True) + @require_bitsandbytes @slow def test_model_integration_test(self): """ @@ -291,7 +300,7 @@ class ColQwen2ModelIntegrationTest(unittest.TestCase): model = ColQwen2ForRetrieval.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, - device_map=torch_device, + load_in_8bit=True, ).eval() # Load the test dataset @@ -319,13 +328,20 @@ class ColQwen2ModelIntegrationTest(unittest.TestCase): self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all()) # Further validation: fine-grained check, with a hardcoded score from the original Hf implementation. - expected_scores = torch.tensor( - [ - [16.2500, 7.8750, 14.6875], - [9.5000, 17.1250, 10.5000], - [14.9375, 10.9375, 20.0000], - ], - dtype=scores.dtype, + expectations = Expectations( + { + ("cuda", 7): [ + [15.5000, 8.1250, 14.9375], + [9.0625, 17.1250, 10.6875], + [15.9375, 12.1875, 20.2500], + ], + ("cuda", 8): [ + [15.1250, 8.6875, 15.0625], + [9.2500, 17.2500, 10.3750], + [15.9375, 12.3750, 20.2500], + ], + } ) + expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype) assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"