enable torchao quantization on CPU (#36146)

* enable torchao quantization on CPU

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix int4

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable CPU torchao tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cuda tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cpu tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix style

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cuda tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torchao available

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torchao available

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torchao config cannot convert to json

* fix docs

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* rm to_dict to rebase

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* limited torchao version for CPU

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix skip

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Update src/transformers/testing_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fix cpu test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
commit 9d6abf9778 (parent 401543a825)
Author: jiqing-feng
Date: 2025-02-25 18:06:52 +08:00 (committed by GitHub)
5 changed files with 125 additions and 70 deletions

docs/source/en/quantization/overview.md

@@ -59,7 +59,7 @@ Use the table below to help you decide which quantization method to use.
 | [HQQ](./hqq.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
 | [optimum-quanto](./quanto.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
 | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
-| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
+| [torchao](./torchao.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
 | [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
 | [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
 | [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |

docs/source/en/quantization/torchao.md

@@ -22,9 +22,11 @@ pip install --upgrade torch torchao transformers
 By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.
 
 ## Manually Choose Quantization Types and Settings
 `torchao` Provides many commonly used types of quantization, including different dtypes like int4, float8 and different flavors like weight only, dynamic quantization etc., only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into hugigngface transformers currently, but we can add more when needed.
+If you want to run the following codes on CPU even with GPU available, just change `device_map="cpu"` and `quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())` where `layout` comes from `from torchao.dtypes import Int4CPULayout` which is only available from torchao 0.8.0 and higher.
 
 Users can manually specify the quantization types and settings they want to use:
@@ -40,7 +42,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 input_text = "What are we having for dinner?"
-input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)
 
 # auto-compile the quantized model with `cache_implementation="static"` to get speedup
 output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
@@ -59,7 +61,7 @@ def benchmark_fn(func: Callable, *args, **kwargs) -> float:
 MAX_NEW_TOKENS = 1000
 print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
 
-bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
+bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
 output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")  # auto-compile
 print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
@@ -122,7 +124,7 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False)
 # load quantized model
 ckpt_id = "llama3-8b-int4wo-128"  # or huggingface hub model id
-loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
+loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto")
 
 # confirm the speedup
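
For reference, a minimal end-to-end sketch of the CPU path described in the doc change above. It assumes torchao >= 0.8.0 (needed for Int4CPULayout) and reuses the TinyLlama checkpoint from the tests further down; the checkpoint and generation settings are illustrative, not part of the commit.

import torch
from torchao.dtypes import Int4CPULayout  # available from torchao 0.8.0
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative checkpoint
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to(quantized_model.device)
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))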

src/transformers/testing_utils.py

@@ -45,6 +45,7 @@ from unittest.mock import patch
 import huggingface_hub.utils
 import urllib3
 from huggingface_hub import delete_repo
+from packaging import version
 
 from transformers import logging as transformers_logging
@@ -963,6 +964,18 @@ def require_torchao(test_case):
     return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
 
 
+def require_torchao_version_greater_or_equal(torchao_version):
+    def decorator(test_case):
+        correct_torchao_version = is_torchao_available() and version.parse(
+            version.parse(importlib.metadata.version("torchao")).base_version
+        ) >= version.parse(torchao_version)
+        return unittest.skipUnless(
+            correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}."
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_tensorrt_fx(test_case):
     """Decorator marking a test that requires Torch-TensorRT FX"""
     return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
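
A usage sketch for the decorator added above; the test class and its body are hypothetical and only illustrate how the version gate composes with @require_torchao.

import unittest

from transformers.testing_utils import require_torchao, require_torchao_version_greater_or_equal


class Int4CPULayoutImportTest(unittest.TestCase):  # hypothetical test class
    @require_torchao
    @require_torchao_version_greater_or_equal("0.8.0")
    def test_int4_cpu_layout_importable(self):
        # Int4CPULayout only exists in torchao 0.8.0 and higher, which is exactly
        # the condition the new decorator guards against on older installs.
        from torchao.dtypes import Int4CPULayout  # noqa: F401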

src/transformers/utils/quantization_config.py

@@ -1558,7 +1558,17 @@ class TorchAoConfig(QuantizationConfigMixin):
     def get_apply_tensor_subclass(self):
         _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
-        return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
+        quant_type_kwargs = self.quant_type_kwargs.copy()
+        if (
+            not torch.cuda.is_available()
+            and is_torchao_available()
+            and self.quant_type == "int4_weight_only"
+            and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+        ):
+            from torchao.dtypes import Int4CPULayout
+
+            quant_type_kwargs["layout"] = Int4CPULayout()
+
+        return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs)
 
     def __repr__(self):
         config_dict = self.to_dict()
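
A small sketch of what the change above does at call time, assuming a CPU-only machine with torchao >= 0.8.0 installed:

from transformers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", group_size=128)
# On a machine without CUDA, get_apply_tensor_subclass() now copies the kwargs and
# injects layout=Int4CPULayout() before building the int4_weight_only() callable,
# so the same config works unchanged on CPU.
apply_quant = config.get_apply_tensor_subclass()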

tests/quantization/torchao_integration/test_torchao.py

@@ -14,15 +14,18 @@
 # limitations under the License.
 
 import gc
+import importlib.metadata
 import tempfile
 import unittest
 
+from packaging import version
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 from transformers.testing_utils import (
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torchao,
-    torch_device,
+    require_torchao_version_greater_or_equal,
 )
 from transformers.utils import is_torch_available, is_torchao_available
@@ -38,13 +41,17 @@ if is_torchao_available():
     )
     from torchao.quantization.autoquant import AQMixin
 
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"):
+        from torchao.dtypes import Int4CPULayout
+
 
-def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
+def check_torchao_int4_wo_quantized(test_module, qlayer):
     weight = qlayer.weight
-    test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
     test_module.assertEqual(weight.quant_min, 0)
     test_module.assertEqual(weight.quant_max, 15)
-    test_module.assertTrue(isinstance(weight._layout, TensorCoreTiledLayout))
+    test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
+    layout = Int4CPULayout if weight.device.type == "cpu" else TensorCoreTiledLayout
+    test_module.assertTrue(isinstance(weight.tensor_impl._layout, layout))
 
 
 def check_autoquantized(test_module, qlayer):
@@ -60,8 +67,8 @@ def check_forward(test_module, model, batch_size=1, context_size=1024):
     test_module.assertEqual(out.shape[1], context_size)
 
 
-@require_torch_gpu
 @require_torchao
+@require_torchao_version_greater_or_equal("0.8.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -102,15 +109,19 @@ class TorchAoConfigTest(unittest.TestCase):
         quantization_config.to_json_string(use_diff=False)
 
 
-@require_torch_gpu
 @require_torchao
+@require_torchao_version_greater_or_equal("0.8.0")
 class TorchAoTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
     EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    device = "cpu"
+    quant_scheme_kwargs = (
+        {"group_size": 32, "layout": Int4CPULayout()}
+        if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+        else {"group_size": 32}
+    )
 
     def tearDown(self):
         gc.collect()
@@ -121,20 +132,20 @@
         """
         Simple LLM model testing int4 weight only quantization
         """
-        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.bfloat16,
-            device_map=torch_device,
+            device_map=self.device,
             quantization_config=quant_config,
         )
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
-        check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
+        check_torchao_int4_wo_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@@ -143,46 +154,51 @@
         """
         Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization
         """
-        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=None,
-            device_map=torch_device,
+            device_map=self.device,
             quantization_config=quant_config,
         )
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
-        check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
+        check_torchao_int4_wo_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    @require_torch_multi_gpu
-    def test_int4wo_quant_multi_gpu(self):
+    def test_int8_dynamic_activation_int8_weight_quant(self):
         """
-        Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs
-        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS
+        Simple LLM model testing int8_dynamic_activation_int8_weight
         """
-        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
+            device_map=self.device,
             quantization_config=quant_config,
         )
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
-        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
-        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+
+@require_torch_gpu
+class TorchAoGPUTest(TorchAoTest):
+    device = "cuda"
+    quant_scheme_kwargs = {"group_size": 32}
 
     def test_int4wo_offload(self):
         """
@@ -228,32 +244,35 @@
         )
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
-        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
 
         EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside"
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
-    def test_int8_dynamic_activation_int8_weight_quant(self):
+    @require_torch_multi_gpu
+    def test_int4wo_quant_multi_gpu(self):
         """
-        Simple LLM model testing int8_dynamic_activation_int8_weight
+        Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs
+        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS
         """
-        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
-        # Note: we quantize the bfloat16 model on the fly to int4
+        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            device_map=torch_device,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
             quantization_config=quant_config,
         )
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
-        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     def test_autoquant(self):
         """
@@ -264,11 +283,11 @@
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.bfloat16,
-            device_map=torch_device,
+            device_map=self.device,
             quantization_config=quant_config,
         )
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
-        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -283,8 +302,8 @@
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
 
-@require_torch_gpu
 @require_torchao
+@require_torchao_version_greater_or_equal("0.8.0")
 class TorchAoSerializationTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
@@ -292,8 +311,13 @@
     # TODO: investigate why we don't have the same output as the original model for this test
     SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
-    device = "cuda:0"
+    quant_scheme = "int4_weight_only"
+    quant_scheme_kwargs = (
+        {"group_size": 32, "layout": Int4CPULayout()}
+        if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+        else {"group_size": 32}
+    )
+    device = "cpu"
 
     # called only once for all test in this class
     @classmethod
@@ -325,9 +349,9 @@
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
-                self.model_name, torch_dtype=torch.bfloat16, device_map=self.device
+                self.model_name, torch_dtype=torch.bfloat16, device_map=device
            )
-            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
+            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
             output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
 
             self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
@@ -336,46 +360,52 @@
         self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
 
 
-class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
-    device = "cuda:0"
-
-
-class TorchAoSerializationW8Test(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
-    device = "cuda:0"
-
-
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
     ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
-    device = "cpu"
 
-    def test_serialization_expected_output_cuda(self):
+    @require_torch_gpu
+    def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        new_device = "cuda:0"
-        self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
 
 
 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
     ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
-    device = "cpu"
 
-    def test_serialization_expected_output_cuda(self):
+    @require_torch_gpu
+    def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        new_device = "cuda:0"
-        self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+
+
+@require_torch_gpu
+class TorchAoSerializationGPTTest(TorchAoSerializationTest):
+    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
+    device = "cuda:0"
+
+
+@require_torch_gpu
+class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
+    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
+
+
+@require_torch_gpu
+class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
+    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
 
 
 if __name__ == "__main__":
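
The test reorganization above parametrizes the same test bodies over devices by subclassing and overriding a couple of class attributes; a condensed, self-contained illustration of that pattern (class names here are hypothetical, not part of the commit):

import unittest


class _CPUQuantTest(unittest.TestCase):
    # Base class pins everything to CPU; subclasses only override the placement knobs,
    # mirroring how TorchAoGPUTest and the *GPUTest serialization classes work above.
    device = "cpu"
    quant_scheme_kwargs = {"group_size": 32}

    def test_device_attribute(self):
        self.assertIn(self.device, ("cpu", "cuda", "cuda:0"))


class _GPUQuantTest(_CPUQuantTest):
    device = "cuda"


if __name__ == "__main__":
    unittest.main()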