mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
parent
a5dfb98977
commit
d228f50acc
@ -19,7 +19,6 @@ from transformers.testing_utils import (
|
||||
is_torch_available,
|
||||
require_accelerate,
|
||||
require_quark,
|
||||
require_read_token,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
@ -44,7 +43,7 @@ class QuarkConfigTest(unittest.TestCase):
|
||||
@require_quark
|
||||
@require_torch_gpu
|
||||
class QuarkTest(unittest.TestCase):
|
||||
reference_model_name = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
reference_model_name = "unsloth/Meta-Llama-3.1-8B-Instruct"
|
||||
quantized_model_name = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
|
||||
|
||||
input_text = "Today I am in Paris and"
|
||||
@ -76,13 +75,11 @@ class QuarkTest(unittest.TestCase):
|
||||
device_map=cls.device_map,
|
||||
)
|
||||
|
||||
@require_read_token
|
||||
def test_memory_footprint(self):
|
||||
mem_quantized = self.quantized_model.get_memory_footprint()
|
||||
|
||||
self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
|
||||
|
||||
@require_read_token
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
|
||||
@ -96,7 +93,6 @@ class QuarkTest(unittest.TestCase):
|
||||
# Tries with a `dtype``
|
||||
self.quantized_model.to(torch.float16)
|
||||
|
||||
@require_read_token
|
||||
def test_original_dtype(self):
|
||||
r"""
|
||||
A simple test to check if the model succesfully stores the original dtype
|
||||
@ -107,7 +103,6 @@ class QuarkTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
|
||||
|
||||
@require_read_token
|
||||
def check_inference_correctness(self, model):
|
||||
r"""
|
||||
Test the generation quality of the quantized model and see that we are matching the expected output.
|
||||
@ -131,7 +126,6 @@ class QuarkTest(unittest.TestCase):
|
||||
# Get the generation
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
@require_read_token
|
||||
def test_generate_quality(self):
|
||||
"""
|
||||
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
|
||||
|
Loading…
Reference in New Issue
Block a user