mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Fix torchao usage (#37034)
* fix load path Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix path Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * Fix torchao usage Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * revert useless change Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * revert fp8 test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix fp8 test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix fp8 test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix torch dtype Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
0fb8d49e88
commit
99f9f1042f
@ -20,7 +20,7 @@ import dataclasses
|
||||
import importlib.metadata
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from inspect import Parameter, signature
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
@ -1627,6 +1627,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
and is_torchao_available()
|
||||
and self.quant_type == "int4_weight_only"
|
||||
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
|
||||
and quant_type_kwargs.get("layout", None) is None
|
||||
):
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
@ -1643,7 +1644,17 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
if isinstance(self.quant_type, str):
|
||||
# Handle layout serialization if present
|
||||
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
|
||||
d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
|
||||
if is_dataclass(d["quant_type_kwargs"]["layout"]):
|
||||
d["quant_type_kwargs"]["layout"] = [
|
||||
d["quant_type_kwargs"]["layout"].__class__.__name__,
|
||||
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
|
||||
]
|
||||
if isinstance(d["quant_type_kwargs"]["layout"], list):
|
||||
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layour kwargs"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
|
||||
else:
|
||||
raise ValueError("layout must be a list")
|
||||
else:
|
||||
# Handle AOBaseConfig serialization
|
||||
from torchao.core.config import config_to_dict
|
||||
@ -1661,6 +1672,9 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
if isinstance(quant_type, str):
|
||||
return cls(quant_type=quant_type, **config_dict)
|
||||
# Check if we only have one key which is "default"
|
||||
# In the future we may update this
|
||||
assert len(quant_type) == 1 and "default" in quant_type, (
|
||||
|
@ -104,8 +104,8 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
|
||||
d = quantization_config.to_dict()
|
||||
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
|
||||
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
|
||||
self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
|
||||
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
|
||||
quantization_config.to_json_string(use_diff=False)
|
||||
|
||||
|
||||
@ -159,7 +159,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Note: we quantize the bfloat16 model on the fly to int4
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=None,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
@ -282,7 +282,7 @@ class TorchAoGPUTest(TorchAoTest):
|
||||
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
torch_dtype="auto",
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
@ -295,7 +295,7 @@ class TorchAoGPUTest(TorchAoTest):
|
||||
|
||||
check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
|
||||
EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
|
||||
output = quantized_model.generate(
|
||||
**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
|
||||
)
|
||||
@ -307,9 +307,7 @@ class TorchAoGPUTest(TorchAoTest):
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
input_text = "What are we having for dinner?"
|
||||
max_new_tokens = 10
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||
# TODO: investigate why we don't have the same output as the original model for this test
|
||||
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
quant_scheme = "int4_weight_only"
|
||||
quant_scheme_kwargs = (
|
||||
@ -326,9 +324,10 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
|
||||
torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
|
||||
self.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=self.device,
|
||||
quantization_config=self.quant_config,
|
||||
)
|
||||
@ -342,16 +341,17 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def check_serialization_expected_output(self, device, expected_output):
|
||||
"""
|
||||
Test if we can serialize and load/infer the model again on the same device
|
||||
"""
|
||||
torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, torch_dtype=torch.bfloat16, device_map=device
|
||||
tmpdirname, torch_dtype=torch_dtype, device_map=device
|
||||
)
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
|
||||
|
||||
@ -359,33 +359,31 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
|
||||
|
||||
def test_serialization_expected_output(self):
|
||||
self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
|
||||
@require_torch_gpu
|
||||
def test_serialization_expected_output_on_cuda(self):
|
||||
"""
|
||||
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||
"""
|
||||
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
|
||||
@require_torch_gpu
|
||||
def test_serialization_expected_output_on_cuda(self):
|
||||
"""
|
||||
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||
"""
|
||||
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest):
|
||||
@require_torch_gpu
|
||||
class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.10.0")
|
||||
class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
def setUp(self):
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
|
||||
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
||||
|
||||
from torchao.quantization import Float8WeightOnlyConfig
|
||||
|
||||
self.quant_scheme = Float8WeightOnlyConfig()
|
||||
self.quant_scheme_kwargs = {}
|
||||
super().setUp()
|
||||
cls.quant_scheme = Float8WeightOnlyConfig()
|
||||
cls.quant_scheme_kwargs = {}
|
||||
|
||||
super().setUpClass()
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.10.0")
|
||||
class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
def setUp(self):
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
|
||||
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
||||
|
||||
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
|
||||
|
||||
self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
|
||||
self.quant_scheme_kwargs = {}
|
||||
super().setUp()
|
||||
cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
|
||||
cls.quant_scheme_kwargs = {}
|
||||
|
||||
super().setUpClass()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user