Fix torchao usage (#37034)

* fix load path

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix path

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix torchao usage

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert useless change

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert fp8 test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix fp8 test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix fp8 test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torch dtype

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
jiqing-feng <jiqing.feng@intel.com>, 2025-04-07 20:50:48 +08:00, committed by GitHub
parent 0fb8d49e88
commit 99f9f1042f
2 changed files with 50 additions and 36 deletions
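
For context, a minimal sketch of the serialization behavior this commit introduces, mirroring the updated TorchAoConfigTest below (assumes torchao >= 0.8.0 is installed so TensorCoreTiledLayout is available):

    # Layout objects now serialize as a [class_name, kwargs] pair instead of a bare dict.
    from torchao.dtypes import TensorCoreTiledLayout
    from transformers import TorchAoConfig

    quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
    d = quantization_config.to_dict()
    assert isinstance(d["quant_type_kwargs"]["layout"], list)      # ["TensorCoreTiledLayout", {...}]
    assert "inner_k_tiles" in d["quant_type_kwargs"]["layout"][1]  # layout kwargs live in the dict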

src/transformers/utils/quantization_config.py

@@ -20,7 +20,7 @@ import dataclasses
 import importlib.metadata
 import json
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from enum import Enum
 from inspect import Parameter, signature
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -1627,6 +1627,7 @@ class TorchAoConfig(QuantizationConfigMixin):
             and is_torchao_available()
             and self.quant_type == "int4_weight_only"
             and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+            and quant_type_kwargs.get("layout", None) is None
         ):
             from torchao.dtypes import Int4CPULayout
@@ -1643,7 +1644,17 @@ class TorchAoConfig(QuantizationConfigMixin):
         if isinstance(self.quant_type, str):
             # Handle layout serialization if present
             if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
-                d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
+                if is_dataclass(d["quant_type_kwargs"]["layout"]):
+                    d["quant_type_kwargs"]["layout"] = [
+                        d["quant_type_kwargs"]["layout"].__class__.__name__,
+                        dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
+                    ]
+                if isinstance(d["quant_type_kwargs"]["layout"], list):
+                    assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
+                else:
+                    raise ValueError("layout must be a list")
         else:
             # Handle AOBaseConfig serialization
             from torchao.core.config import config_to_dict
@@ -1661,6 +1672,9 @@ class TorchAoConfig(QuantizationConfigMixin):
         assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
         config_dict = config_dict.copy()
         quant_type = config_dict.pop("quant_type")
+        if isinstance(quant_type, str):
+            return cls(quant_type=quant_type, **config_dict)
         # Check if we only have one key which is "default"
         # In the future we may update this
         assert len(quant_type) == 1 and "default" in quant_type, (
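
The [layout_name, layout_kwargs] pair that to_dict now emits can be built with nothing but the dataclasses module; a standalone sketch of the convention the asserts above enforce (inner_k_tiles shown with its assumed default of 8):

    import dataclasses

    from torchao.dtypes import TensorCoreTiledLayout

    layout = TensorCoreTiledLayout()
    serialized = [layout.__class__.__name__, dataclasses.asdict(layout)]
    # -> ["TensorCoreTiledLayout", {"inner_k_tiles": 8}]
    assert isinstance(serialized[0], str) and isinstance(serialized[1], dict)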

tests/quantization/torchao_integration/test_torchao.py

@@ -104,8 +104,8 @@ class TorchAoConfigTest(unittest.TestCase):
         """
         quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
         d = quantization_config.to_dict()
-        self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
-        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
+        self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
+        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
         quantization_config.to_json_string(use_diff=False)
@@ -159,7 +159,7 @@ class TorchAoTest(unittest.TestCase):
         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=None,
+            torch_dtype=torch.bfloat16,
             device_map=self.device,
             quantization_config=quant_config,
         )
@@ -282,7 +282,7 @@ class TorchAoGPUTest(TorchAoTest):
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype="auto",
             device_map=self.device,
             quantization_config=quant_config,
         )
@@ -295,7 +295,7 @@ class TorchAoGPUTest(TorchAoTest):
         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -307,9 +307,7 @@ class TorchAoGPUTest(TorchAoTest):
 class TorchAoSerializationTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
-    # TODO: investigate why we don't have the same output as the original model for this test
-    SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
@@ -326,9 +324,10 @@ class TorchAoSerializationTest(unittest.TestCase):
     def setUp(self):
         self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         self.quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch_dtype,
             device_map=self.device,
             quantization_config=self.quant_config,
         )
@@ -342,16 +341,17 @@ class TorchAoSerializationTest(unittest.TestCase):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     def check_serialization_expected_output(self, device, expected_output):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
-                self.model_name, torch_dtype=torch.bfloat16, device_map=device
+                tmpdirname, torch_dtype=torch_dtype, device_map=device
             )
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
@@ -359,33 +359,31 @@ class TorchAoSerializationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
 
     def test_serialization_expected_output(self):
-        self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
 
 
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
 
     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
 
 
 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
 
     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
 
 
 @require_torch_gpu
@@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest):
 
 @require_torch_gpu
 class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
 
 @require_torch_gpu
 class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
 
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
-    def setUp(self):
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
 
         from torchao.quantization import Float8WeightOnlyConfig
 
-        self.quant_scheme = Float8WeightOnlyConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Float8WeightOnlyConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()
 
 
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
-    def setUp(self):
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
 
         from torchao.quantization import Int8DynamicActivationInt4WeightConfig
 
-        self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()
 
 
 if __name__ == "__main__":
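
Taken together, the test changes above exercise this save/reload flow; a minimal sketch (int8_weight_only on CPU for brevity, torch_dtype="auto" per the new dtype selection, and safe_serialization=False as in the tests):

    import tempfile

    from transformers import AutoModelForCausalLM, TorchAoConfig

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    quant_config = TorchAoConfig("int8_weight_only")

    # Quantize on the fly while loading the original checkpoint.
    quantized = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="cpu", quantization_config=quant_config
    )

    with tempfile.TemporaryDirectory() as tmpdirname:
        quantized.save_pretrained(tmpdirname, safe_serialization=False)
        # The fix: reload from the serialized directory, not the original model id.
        reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype="auto", device_map="cpu")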