update ColQwen2ModelIntegrationTest (#38583)

* update

* update

* update

* update

* 4 bit

* 8 bit

* final

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-06-06 10:41:17 +02:00 committed by GitHub
parent dbfc79c17c
commit 92a87134ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -26,7 +26,15 @@ from transformers import is_torch_available
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2ForRetrievalOutput
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor
from transformers.testing_utils import cleanup, require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
Expectations,
cleanup,
require_bitsandbytes,
require_torch,
require_vision,
slow,
torch_device,
)
if is_torch_available():
@@ -283,6 +291,7 @@ class ColQwen2ModelIntegrationTest(unittest.TestCase):
# unittest hook run after each test method of this TestCase: calls `cleanup`
# on `torch_device` with `gc_collect=True`.
# NOTE(review): presumably this releases accelerator memory between tests —
# confirm against `transformers.testing_utils.cleanup`.
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@require_bitsandbytes
@slow
def test_model_integration_test(self):
"""
@@ -291,7 +300,7 @@ class ColQwen2ModelIntegrationTest(unittest.TestCase):
model = ColQwen2ForRetrieval.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=torch_device,
load_in_8bit=True,
).eval()
# Load the test dataset
@@ -319,13 +328,20 @@ class ColQwen2ModelIntegrationTest(unittest.TestCase):
self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())
# Further validation: fine-grained check, with a hardcoded score from the original Hf implementation.
expected_scores = torch.tensor(
[
[16.2500, 7.8750, 14.6875],
[9.5000, 17.1250, 10.5000],
[14.9375, 10.9375, 20.0000],
],
dtype=scores.dtype,
expectations = Expectations(
{
("cuda", 7): [
[15.5000, 8.1250, 14.9375],
[9.0625, 17.1250, 10.6875],
[15.9375, 12.1875, 20.2500],
],
("cuda", 8): [
[15.1250, 8.6875, 15.0625],
[9.2500, 17.2500, 10.3750],
[15.9375, 12.3750, 20.2500],
],
}
)
expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype)
assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"