Fix CI by tweaking torchao tests (#34832)

This commit is contained in:
Marc Sun 2024-11-20 20:28:51 +01:00 committed by GitHub
parent bf42c3bd4b
commit 3cb8676a91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 7 deletions

View File

@@ -1264,8 +1264,13 @@ class TorchAoConfig(QuantizationConfigMixin):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
raise ValueError("Requires torchao 0.4.0 version and above")
if is_torchao_available():
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
raise ValueError("Requires torchao 0.4.0 version and above")
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
)
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
if self.quant_type not in _STR_TO_METHOD.keys():

View File

@@ -246,12 +246,13 @@ class TorchAoSerializationTest(unittest.TestCase):
# TODO: investigate why we don't have the same output as the original model for this test
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
device = "cuda:0"
# called only once for all tests in this class
@classmethod
def setUpClass(cls):
cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
torch_dtype=torch.bfloat16,
@@ -290,21 +291,21 @@ class TorchAoSerializationTest(unittest.TestCase):
class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
class TorchAoSerializationW8Test(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_weight_only")
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cpu"
@@ -318,7 +319,7 @@ class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_weight_only")
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cpu"