Fix require_read_token (#37422)

* nit

* fix

* fix
This commit is contained in:
Mohamed Mekkouri 2025-04-10 17:01:40 +02:00 committed by GitHub
parent bde41d69b4
commit 9c0c323e12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 3 deletions

View File

@ -156,7 +156,7 @@ class DonutSwinImageClassifierOutput(ModelOutput):
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
logits: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

View File

@ -57,7 +57,6 @@ class QuarkTest(unittest.TestCase):
EXPECTED_RELATIVE_DIFFERENCE = 1.66
device_map = None
@require_read_token
@classmethod
def setUpClass(cls):
"""
@ -76,15 +75,17 @@ class QuarkTest(unittest.TestCase):
device_map=cls.device_map,
)
@require_read_token
def test_memory_footprint(self):
mem_quantized = self.quantized_model.get_memory_footprint()
self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
@require_read_token
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
Checks also if other models are casted correctly.
Checks also if other models are casted correctly .
"""
# This should work
if self.device_map is None:
@ -94,6 +95,7 @@ class QuarkTest(unittest.TestCase):
# Tries with a `dtype``
self.quantized_model.to(torch.float16)
@require_read_token
def test_original_dtype(self):
r"""
A simple test to check if the model succesfully stores the original dtype
@ -104,6 +106,7 @@ class QuarkTest(unittest.TestCase):
self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
@require_read_token
def check_inference_correctness(self, model):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
@ -127,6 +130,7 @@ class QuarkTest(unittest.TestCase):
# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_read_token
def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens