mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix CI by tweaking torchao tests (#34832)
This commit is contained in:
parent
bf42c3bd4b
commit
3cb8676a91
@ -1264,8 +1264,13 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
r"""
|
||||
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
||||
"""
|
||||
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
|
||||
raise ValueError("Requires torchao 0.4.0 version and above")
|
||||
if is_torchao_available():
|
||||
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
|
||||
raise ValueError("Requires torchao 0.4.0 version and above")
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||
)
|
||||
|
||||
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
|
||||
if self.quant_type not in _STR_TO_METHOD.keys():
|
||||
|
@ -246,12 +246,13 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
# TODO: investigate why we don't have the same output as the original model for this test
|
||||
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
|
||||
device = "cuda:0"
|
||||
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
cls.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
@ -290,21 +291,21 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
|
||||
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
class TorchAoSerializationW8Test(TorchAoSerializationTest):
|
||||
quant_config = TorchAoConfig("int8_weight_only")
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
||||
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cpu"
|
||||
@ -318,7 +319,7 @@ class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
||||
|
||||
|
||||
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
|
||||
quant_config = TorchAoConfig("int8_weight_only")
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cpu"
|
||||
|
Loading…
Reference in New Issue
Block a user