Fix : HQQ config when hqq not available (#35655)

* fix

* make style

* adding require_hqq

* make style
This commit is contained in:
Mohamed Mekkouri 2025-01-14 11:37:37 +01:00 committed by GitHub
parent 715fdd6459
commit 050636518a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 0 deletions

View File

@ -87,6 +87,7 @@ from .utils import (
is_gguf_available,
is_grokadamw_available,
is_hadamard_available,
is_hqq_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@ -1213,6 +1214,13 @@ def require_auto_gptq(test_case):
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
def require_hqq(test_case):
"""
Decorator for hqq dependency
"""
return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case)
def require_auto_awq(test_case):
"""
Decorator for auto_awq dependency

View File

@ -224,6 +224,10 @@ class HqqConfig(QuantizationConfigMixin):
):
if is_hqq_available():
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
else:
raise ImportError(
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
)
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
if deprecated_key in kwargs:

View File

@ -19,6 +19,7 @@ import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
from transformers.testing_utils import (
require_accelerate,
require_hqq,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@ -86,6 +87,7 @@ MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@require_torch_gpu
@require_hqq
class HqqConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
@ -100,6 +102,7 @@ class HqqConfigTest(unittest.TestCase):
@slow
@require_torch_gpu
@require_accelerate
@require_hqq
class HQQTest(unittest.TestCase):
def tearDown(self):
cleanup()
@ -122,6 +125,7 @@ class HQQTest(unittest.TestCase):
@require_torch_gpu
@require_torch_multi_gpu
@require_accelerate
@require_hqq
class HQQTestMultiGPU(unittest.TestCase):
def tearDown(self):
cleanup()
@ -144,6 +148,7 @@ class HQQTestMultiGPU(unittest.TestCase):
@slow
@require_torch_gpu
@require_accelerate
@require_hqq
class HQQSerializationTest(unittest.TestCase):
def tearDown(self):
cleanup()