Fix torchao usage (#37034)
* fix load path
* fix path
* Fix torchao usage
* fix tests
* fix format
* revert useless change
* format
* revert fp8 test
* fix fp8 test
* fix fp8 test
* fix torch dtype

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

commit 99f9f1042f
parent 0fb8d49e88
@@ -20,7 +20,7 @@ import dataclasses
 import importlib.metadata
 import json
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from enum import Enum
 from inspect import Parameter, signature
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -1627,6 +1627,7 @@ class TorchAoConfig(QuantizationConfigMixin):
             and is_torchao_available()
             and self.quant_type == "int4_weight_only"
             and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+            and quant_type_kwargs.get("layout", None) is None
         ):
             from torchao.dtypes import Int4CPULayout
 
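In effect, the CPU-specific Int4CPULayout is now only substituted when the caller has not already picked a layout. A minimal sketch of the intended behavior (assumes torchao >= 0.8.0 and a transformers build containing this change):

from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig

# No layout passed: post-init may fall back to Int4CPULayout on CPU-only setups.
cfg_default = TorchAoConfig("int4_weight_only", group_size=32)

# Explicit layout passed: the guard added above now leaves the user's choice untouched.
cfg_explicit = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())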
@@ -1643,7 +1644,17 @@ class TorchAoConfig(QuantizationConfigMixin):
         if isinstance(self.quant_type, str):
             # Handle layout serialization if present
             if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
-                d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
+                if is_dataclass(d["quant_type_kwargs"]["layout"]):
+                    d["quant_type_kwargs"]["layout"] = [
+                        d["quant_type_kwargs"]["layout"].__class__.__name__,
+                        dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
+                    ]
+                if isinstance(d["quant_type_kwargs"]["layout"], list):
+                    assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layour kwargs"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
+                else:
+                    raise ValueError("layout must be a list")
         else:
             # Handle AOBaseConfig serialization
             from torchao.core.config import config_to_dict
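A dataclass layout is therefore serialized as a two-element list, [layout class name, layout kwargs], rather than a bare kwargs dict that dropped the layout's identity. Roughly (the inner_k_tiles value shown is TensorCoreTiledLayout's default, used here for illustration):

from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
d = config.to_dict()
# before this fix: {"inner_k_tiles": 8}                            -- class name lost
# after this fix:  ["TensorCoreTiledLayout", {"inner_k_tiles": 8}] -- name + kwargs
print(d["quant_type_kwargs"]["layout"])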
@@ -1661,6 +1672,9 @@ class TorchAoConfig(QuantizationConfigMixin):
         assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
         config_dict = config_dict.copy()
         quant_type = config_dict.pop("quant_type")
+
+        if isinstance(quant_type, str):
+            return cls(quant_type=quant_type, **config_dict)
         # Check if we only have one key which is "default"
         # In the future we may update this
         assert len(quant_type) == 1 and "default" in quant_type, (
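The early return lets a plain string quant_type round-trip through from_dict without being treated as a {"default": ...} mapping; the hunks that follow update the torchao integration tests accordingly. A hedged sketch (assumes torchao > 0.9.0 is installed, which the assert above enforces):

from transformers import TorchAoConfig

# String quant_type now takes the early-return path shown above.
config = TorchAoConfig.from_dict({"quant_type": "int8_weight_only", "quant_type_kwargs": {}})
assert config.quant_type == "int8_weight_only"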
@@ -104,8 +104,8 @@ class TorchAoConfigTest(unittest.TestCase):
         """
         quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
         d = quantization_config.to_dict()
-        self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
-        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
+        self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
+        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
         quantization_config.to_json_string(use_diff=False)
 
 
@@ -159,7 +159,7 @@ class TorchAoTest(unittest.TestCase):
         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=None,
+            torch_dtype=torch.bfloat16,
             device_map=self.device,
             quantization_config=quant_config,
         )
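Pinning torch_dtype=torch.bfloat16 matters because torchao's int4_weight_only path expects bf16 weights; torch_dtype=None can leave the checkpoint in fp32, which that scheme does not support. A load sketch along the lines of the test (model id taken from the surrounding file; CPU placement is an assumption):

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

quant_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.bfloat16,  # int4 weight-only quantization expects bf16
    device_map="cpu",
    quantization_config=quant_config,
)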
@@ -282,7 +282,7 @@ class TorchAoGPUTest(TorchAoTest):
 
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype="auto",
             device_map=self.device,
             quantization_config=quant_config,
         )
@@ -295,7 +295,7 @@ class TorchAoGPUTest(TorchAoTest):
 
         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -307,9 +307,7 @@ class TorchAoGPUTest(TorchAoTest):
 class TorchAoSerializationTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
-    # TODO: investigate why we don't have the same output as the original model for this test
-    SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
@@ -326,9 +324,10 @@ class TorchAoSerializationTest(unittest.TestCase):
 
     def setUp(self):
         self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         self.quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch_dtype,
             device_map=self.device,
             quantization_config=self.quant_config,
         )
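The same dtype rule the tests now apply, extracted as a one-liner for clarity: only int4_weight_only needs the explicit bf16 pin, every other scheme defers to the checkpoint via "auto". (The helper name is illustrative, not part of the commit.)

import torch

def pick_torch_dtype(quant_scheme):
    # Mirrors the conditional added in setUp / check_serialization_expected_output.
    return torch.bfloat16 if quant_scheme == "int4_weight_only" else "auto"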
@@ -342,16 +341,17 @@ class TorchAoSerializationTest(unittest.TestCase):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
 
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     def check_serialization_expected_output(self, device, expected_output):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
-                self.model_name, torch_dtype=torch.bfloat16, device_map=device
+                tmpdirname, torch_dtype=torch_dtype, device_map=device
             )
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
 
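This is the "fix load path" part of the commit: the old test reloaded from self.model_name, so the checkpoint written to tmpdirname was never actually read back and serialization went untested. The fixed flow, as a standalone sketch (scheme and device chosen here for illustration):

import tempfile

from transformers import AutoModelForCausalLM, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = TorchAoConfig("int8_weight_only")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="cpu", quantization_config=quant_config
)

with tempfile.TemporaryDirectory() as tmpdirname:
    # torchao tensor subclasses are saved without safetensors, as in the test
    model.save_pretrained(tmpdirname, safe_serialization=False)
    # Crucially, reload from tmpdirname -- not from model_name.
    reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype="auto", device_map="cpu")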
@@ -359,33 +359,31 @@ class TorchAoSerializationTest(unittest.TestCase):
             self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
 
     def test_serialization_expected_output(self):
-        self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
 
 
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
 
     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
 
 
 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
 
     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
 
 
 @require_torch_gpu
@@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest):
 @require_torch_gpu
 class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
 
 @require_torch_gpu
 class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
 
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
-    def setUp(self):
+    # called only once for all test in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
 
         from torchao.quantization import Float8WeightOnlyConfig
 
-        self.quant_scheme = Float8WeightOnlyConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Float8WeightOnlyConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()
 
 
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"
 
-    def setUp(self):
+    # called only once for all test in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
 
         from torchao.quantization import Int8DynamicActivationInt4WeightConfig
 
-        self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()
 
 
 if __name__ == "__main__":
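Switching from setUp to setUpClass means the capability check and quant scheme construction run once per class rather than before every test, and the inherited per-test setUp then reads the scheme off the class. A generic sketch of the pattern (class and scheme names here are illustrative):

import unittest


class QuantTestBase(unittest.TestCase):
    quant_scheme = "int8_weight_only"

    def setUp(self):
        # Runs before each test; sees whatever setUpClass configured on the class.
        self.config_args = {"scheme": self.quant_scheme}


class Fp8QuantTest(QuantTestBase):
    @classmethod
    def setUpClass(cls):
        # Runs once for the whole class; the right place for skips and
        # expensive one-time configuration.
        cls.quant_scheme = "float8_weight_only"
        super().setUpClass()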