Fix torchao usage (#37034)

* fix load path

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix path

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix torchao usage

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert useless change

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert fp8 test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix fp8 test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix fp8 test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torch dtype

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
jiqing-feng 2025-04-07 20:50:48 +08:00 committed by GitHub
parent 0fb8d49e88
commit 99f9f1042f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 36 deletions

View File

@ -20,7 +20,7 @@ import dataclasses
import importlib.metadata
import json
import os
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from enum import Enum
from inspect import Parameter, signature
from typing import Any, Dict, List, Optional, Tuple, Union
@ -1627,6 +1627,7 @@ class TorchAoConfig(QuantizationConfigMixin):
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
and quant_type_kwargs.get("layout", None) is None
):
from torchao.dtypes import Int4CPULayout
@ -1643,7 +1644,17 @@ class TorchAoConfig(QuantizationConfigMixin):
if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
if is_dataclass(d["quant_type_kwargs"]["layout"]):
d["quant_type_kwargs"]["layout"] = [
d["quant_type_kwargs"]["layout"].__class__.__name__,
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
]
if isinstance(d["quant_type_kwargs"]["layout"], list):
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layour kwargs"
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
else:
raise ValueError("layout must be a list")
else:
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
@ -1661,6 +1672,9 @@ class TorchAoConfig(QuantizationConfigMixin):
assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
if isinstance(quant_type, str):
return cls(quant_type=quant_type, **config_dict)
# Check if we only have one key which is "default"
# In the future we may update this
assert len(quant_type) == 1 and "default" in quant_type, (

View File

@ -104,8 +104,8 @@ class TorchAoConfigTest(unittest.TestCase):
"""
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
d = quantization_config.to_dict()
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
quantization_config.to_json_string(use_diff=False)
@ -159,7 +159,7 @@ class TorchAoTest(unittest.TestCase):
# Note: we quantize the bfloat16 model on the fly to int4
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=None,
torch_dtype=torch.bfloat16,
device_map=self.device,
quantization_config=quant_config,
)
@ -282,7 +282,7 @@ class TorchAoGPUTest(TorchAoTest):
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
torch_dtype="auto",
device_map=self.device,
quantization_config=quant_config,
)
@ -295,7 +295,7 @@ class TorchAoGPUTest(TorchAoTest):
check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
output = quantized_model.generate(
**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
)
@ -307,9 +307,7 @@ class TorchAoGPUTest(TorchAoTest):
class TorchAoSerializationTest(unittest.TestCase):
input_text = "What are we having for dinner?"
max_new_tokens = 10
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
# TODO: investigate why we don't have the same output as the original model for this test
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_scheme = "int4_weight_only"
quant_scheme_kwargs = (
@ -326,9 +324,10 @@ class TorchAoSerializationTest(unittest.TestCase):
def setUp(self):
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
self.quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
torch_dtype=torch_dtype,
device_map=self.device,
quantization_config=self.quant_config,
)
@ -342,16 +341,17 @@ class TorchAoSerializationTest(unittest.TestCase):
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def check_serialization_expected_output(self, device, expected_output):
"""
Test if we can serialize and load/infer the model again on the same device
"""
torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16, device_map=device
tmpdirname, torch_dtype=torch_dtype, device_map=device
)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
@ -359,33 +359,31 @@ class TorchAoSerializationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
def test_serialization_expected_output(self):
self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
@require_torch_gpu
def test_serialization_expected_output_on_cuda(self):
"""
Test if we can serialize on device (cpu) and load/infer the model on cuda
"""
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
@require_torch_gpu
def test_serialization_expected_output_on_cuda(self):
"""
Test if we can serialize on device (cpu) and load/infer the model on cuda
"""
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
@require_torch_gpu
@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest):
@require_torch_gpu
class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
device = "cuda:0"
@require_torch_gpu
class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
device = "cuda:0"
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
device = "cuda:0"
def setUp(self):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Float8WeightOnlyConfig
self.quant_scheme = Float8WeightOnlyConfig()
self.quant_scheme_kwargs = {}
super().setUp()
cls.quant_scheme = Float8WeightOnlyConfig()
cls.quant_scheme_kwargs = {}
super().setUpClass()
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
device = "cuda:0"
def setUp(self):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
self.quant_scheme_kwargs = {}
super().setUp()
cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
cls.quant_scheme_kwargs = {}
super().setUpClass()
if __name__ == "__main__":