mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
enable torchao quantization on CPU (#36146)
* enable torchao quantization on CPU Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix int4 Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable CPU torchao tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix cuda tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix cpu tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix style Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix cuda tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix torchao available Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix torchao available Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix torchao config cannot convert to json * fix docs Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * rm to_dict to rebase Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * limited torchao version for CPU Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix skip Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * Update src/transformers/testing_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix cpu test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
401543a825
commit
9d6abf9778
@ -59,7 +59,7 @@ Use the table below to help you decide which quantization method to use.
|
||||
| [HQQ](./hqq.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
|
||||
| [optimum-quanto](./quanto.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
|
||||
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
|
||||
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
|
||||
| [torchao](./torchao.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
|
||||
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
|
||||
| [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
|
||||
| [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
|
||||
|
@ -22,9 +22,11 @@ pip install --upgrade torch torchao transformers
|
||||
|
||||
By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.
|
||||
|
||||
|
||||
## Manually Choose Quantization Types and Settings
|
||||
|
||||
`torchao` Provides many commonly used types of quantization, including different dtypes like int4, float8 and different flavors like weight only, dynamic quantization etc., only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into hugigngface transformers currently, but we can add more when needed.
|
||||
If you want to run the following codes on CPU even with GPU available, just change `device_map="cpu"` and `quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())` where `layout` comes from `from torchao.dtypes import Int4CPULayout` which is only available from torchao 0.8.0 and higher.
|
||||
|
||||
Users can manually specify the quantization types and settings they want to use:
|
||||
|
||||
@ -40,7 +42,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speedup
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
@ -59,7 +61,7 @@ def benchmark_fn(func: Callable, *args, **kwargs) -> float:
|
||||
MAX_NEW_TOKENS = 1000
|
||||
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
|
||||
|
||||
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
|
||||
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
|
||||
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
|
||||
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
|
||||
|
||||
@ -122,7 +124,7 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False)
|
||||
|
||||
# load quantized model
|
||||
ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id
|
||||
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
|
||||
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto")
|
||||
|
||||
|
||||
# confirm the speedup
|
||||
|
@ -45,6 +45,7 @@ from unittest.mock import patch
|
||||
import huggingface_hub.utils
|
||||
import urllib3
|
||||
from huggingface_hub import delete_repo
|
||||
from packaging import version
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
@ -963,6 +964,18 @@ def require_torchao(test_case):
|
||||
return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
|
||||
|
||||
|
||||
def require_torchao_version_greater_or_equal(torchao_version):
|
||||
def decorator(test_case):
|
||||
correct_torchao_version = is_torchao_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("torchao")).base_version
|
||||
) >= version.parse(torchao_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_torch_tensorrt_fx(test_case):
|
||||
"""Decorator marking a test that requires Torch-TensorRT FX"""
|
||||
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
|
||||
|
@ -1558,7 +1558,17 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
|
||||
def get_apply_tensor_subclass(self):
|
||||
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
|
||||
return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
|
||||
quant_type_kwargs = self.quant_type_kwargs.copy()
|
||||
if (
|
||||
not torch.cuda.is_available()
|
||||
and is_torchao_available()
|
||||
and self.quant_type == "int4_weight_only"
|
||||
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
|
||||
):
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
quant_type_kwargs["layout"] = Int4CPULayout()
|
||||
return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
config_dict = self.to_dict()
|
||||
|
@ -14,15 +14,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from packaging import version
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torchao,
|
||||
torch_device,
|
||||
require_torchao_version_greater_or_equal,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_torchao_available
|
||||
|
||||
@ -38,13 +41,17 @@ if is_torchao_available():
|
||||
)
|
||||
from torchao.quantization.autoquant import AQMixin
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"):
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
||||
|
||||
def check_torchao_int4_wo_quantized(test_module, qlayer):
|
||||
weight = qlayer.weight
|
||||
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
test_module.assertEqual(weight.quant_min, 0)
|
||||
test_module.assertEqual(weight.quant_max, 15)
|
||||
test_module.assertTrue(isinstance(weight._layout, TensorCoreTiledLayout))
|
||||
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
layout = Int4CPULayout if weight.device.type == "cpu" else TensorCoreTiledLayout
|
||||
test_module.assertTrue(isinstance(weight.tensor_impl._layout, layout))
|
||||
|
||||
|
||||
def check_autoquantized(test_module, qlayer):
|
||||
@ -60,8 +67,8 @@ def check_forward(test_module, model, batch_size=1, context_size=1024):
|
||||
test_module.assertEqual(out.shape[1], context_size)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
@require_torchao_version_greater_or_equal("0.8.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@ -102,15 +109,19 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
quantization_config.to_json_string(use_diff=False)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
@require_torchao_version_greater_or_equal("0.8.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
input_text = "What are we having for dinner?"
|
||||
max_new_tokens = 10
|
||||
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
device = "cpu"
|
||||
quant_scheme_kwargs = (
|
||||
{"group_size": 32, "layout": Int4CPULayout()}
|
||||
if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
|
||||
else {"group_size": 32}
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
@ -121,20 +132,20 @@ class TorchAoTest(unittest.TestCase):
|
||||
"""
|
||||
Simple LLM model testing int4 weight only quantization
|
||||
"""
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
|
||||
|
||||
# Note: we quantize the bfloat16 model on the fly to int4
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=torch_device,
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
check_torchao_int4_wo_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
@ -143,46 +154,51 @@ class TorchAoTest(unittest.TestCase):
|
||||
"""
|
||||
Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization
|
||||
"""
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
|
||||
|
||||
# Note: we quantize the bfloat16 model on the fly to int4
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=None,
|
||||
device_map=torch_device,
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
check_torchao_int4_wo_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_int4wo_quant_multi_gpu(self):
|
||||
def test_int8_dynamic_activation_int8_weight_quant(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs
|
||||
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS
|
||||
Simple LLM model testing int8_dynamic_activation_int8_weight
|
||||
"""
|
||||
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
||||
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
EXPECTED_OUTPUT = [
|
||||
"What are we having for dinner?\n\nJessica: (smiling)",
|
||||
"What are we having for dinner?\n\nJess: (smiling) I",
|
||||
]
|
||||
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class TorchAoGPUTest(TorchAoTest):
|
||||
device = "cuda"
|
||||
quant_scheme_kwargs = {"group_size": 32}
|
||||
|
||||
def test_int4wo_offload(self):
|
||||
"""
|
||||
@ -228,32 +244,35 @@ class TorchAoTest(unittest.TestCase):
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside"
|
||||
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
|
||||
def test_int8_dynamic_activation_int8_weight_quant(self):
|
||||
@require_torch_multi_gpu
|
||||
def test_int4wo_quant_multi_gpu(self):
|
||||
"""
|
||||
Simple LLM model testing int8_dynamic_activation_int8_weight
|
||||
Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs
|
||||
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS
|
||||
"""
|
||||
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
||||
|
||||
# Note: we quantize the bfloat16 model on the fly to int4
|
||||
quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
device_map=torch_device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_autoquant(self):
|
||||
"""
|
||||
@ -264,11 +283,11 @@ class TorchAoTest(unittest.TestCase):
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=torch_device,
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
output = quantized_model.generate(
|
||||
**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
|
||||
)
|
||||
@ -283,8 +302,8 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
@require_torchao_version_greater_or_equal("0.8.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
input_text = "What are we having for dinner?"
|
||||
max_new_tokens = 10
|
||||
@ -292,8 +311,13 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
# TODO: investigate why we don't have the same output as the original model for this test
|
||||
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
|
||||
device = "cuda:0"
|
||||
quant_scheme = "int4_weight_only"
|
||||
quant_scheme_kwargs = (
|
||||
{"group_size": 32, "layout": Int4CPULayout()}
|
||||
if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
|
||||
else {"group_size": 32}
|
||||
)
|
||||
device = "cpu"
|
||||
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
@ -325,9 +349,9 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, torch_dtype=torch.bfloat16, device_map=self.device
|
||||
self.model_name, torch_dtype=torch.bfloat16, device_map=device
|
||||
)
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
|
||||
|
||||
output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
|
||||
@ -336,46 +360,52 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
class TorchAoSerializationW8Test(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cpu"
|
||||
|
||||
def test_serialization_expected_output_cuda(self):
|
||||
@require_torch_gpu
|
||||
def test_serialization_expected_output_on_cuda(self):
|
||||
"""
|
||||
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||
"""
|
||||
new_device = "cuda:0"
|
||||
self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cpu"
|
||||
|
||||
def test_serialization_expected_output_cuda(self):
|
||||
@require_torch_gpu
|
||||
def test_serialization_expected_output_on_cuda(self):
|
||||
"""
|
||||
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||
"""
|
||||
new_device = "cuda:0"
|
||||
self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class TorchAoSerializationGPTTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user