Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Torchao weights only + prequantized compatibility (#34355)

* weights only compatibility
* better tests from code review
* ping torch version
* add weights_only check
parent f297af55df
commit 67890de3b8
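In practice, the change makes already-quantized torchao checkpoints loadable again under the default torch.load(weights_only=True) path on torch >= 2.5.0, with weights_only=False as the documented fallback on older torch. A minimal sketch of the load path this targets; the checkpoint directory below is a placeholder for one produced with save_pretrained(safe_serialization=False):

# Sketch only: "./llama-torchao-int8" is a placeholder directory that already
# contains torchao-quantized weights saved with safe_serialization=False.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "./llama-torchao-int8",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    # With torch >= 2.5.0 the default weights_only=True load works for the
    # pre-quantized weights; on older torch, weights_only=False is the escape hatch.
    weights_only=False,
)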
@@ -3602,7 +3602,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         if hf_quantizer is not None:
             hf_quantizer.validate_environment(
-                torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
+                torch_dtype=torch_dtype,
+                from_tf=from_tf,
+                from_flax=from_flax,
+                device_map=device_map,
+                weights_only=weights_only,
             )
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
             device_map = hf_quantizer.update_device_map(device_map)
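With validate_environment now receiving weights_only, any HfQuantizer implementation can react to the load mode. A hedged sketch of what a quantizer-side check could look like; MyQuantizer is hypothetical and only the kwargs plumbing is taken from the diff:

# Hypothetical quantizer subclass; only the weights_only kwarg mirrors the change above.
from transformers.quantizers.base import HfQuantizer

class MyQuantizer(HfQuantizer):
    def validate_environment(self, *args, **kwargs):
        weights_only = kwargs.get("weights_only", None)
        if self.pre_quantized and weights_only:
            # here a quantizer can verify that torch.load(weights_only=True)
            # is able to deserialize its custom tensor subclasses
            ...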
@@ -91,6 +91,15 @@ class TorchAoHfQuantizer(HfQuantizer):
                     )
                 else:
                     self.offload = True
+        if self.pre_quantized:
+            weights_only = kwargs.get("weights_only", None)
+            if weights_only:
+                torch_version = version.parse(importlib.metadata.version("torch"))
+                if torch_version < version.parse("2.5.0"):
+                    raise RuntimeError(
+                        f"In order to use torchao pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
+                        f" You can also set with `weights_only=False` in `from_pretrained` if you don't want to update torch"
+                    )
 
     def update_torch_dtype(self, torch_dtype):
         if self.quantization_config.quant_type == "int4_weight_only":
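The new guard reduces to a single version comparison; the same check can be expressed standalone with the libraries the hunk already uses (importlib.metadata and packaging). The helper name here is illustrative:

import importlib.metadata
from packaging import version

def torch_can_load_torchao_prequantized() -> bool:
    # torch.load(weights_only=True) only handles torchao pre-quantized weights from torch 2.5.0 on
    return version.parse(importlib.metadata.version("torch")) >= version.parse("2.5.0")

# Callers on older torch can fall back to from_pretrained(..., weights_only=False).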
@@ -103,6 +112,10 @@ class TorchAoHfQuantizer(HfQuantizer):
                     "Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
                 )
                 torch_dtype = torch.bfloat16
+        if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
+            if torch_dtype is None:
+                # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
+                torch_dtype = torch.float32
         return torch_dtype
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
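The dtype fallback matters when callers omit torch_dtype. A sketch of the case handled above; the model id is the one used in the tests further down and the behavior description comes straight from the hunk:

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

# int4_weight_only: torch_dtype is forced to torch.bfloat16 (the only supported dtype).
# int8_dynamic_activation_int8_weight with torch_dtype=None: defaults to torch.float32
# so the quantized linear op sees matching dtypes.
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    quantization_config=TorchAoConfig("int8_dynamic_activation_int8_weight"),
    device_map="cuda:0",
    # torch_dtype intentionally omitted -> falls back to float32 per update_torch_dtype
)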
@@ -198,6 +211,12 @@ class TorchAoHfQuantizer(HfQuantizer):
             )
         if not _is_torchao_serializable:
             logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
+        if self.offload and self.quantization_config.modules_to_not_convert is None:
+            logger.warning(
+                "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
+                "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
+            )
+            return False
         return _is_torchao_serializable
 
     @property
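is_serializable gates the save path; the workflow it protects looks roughly like this (save arguments taken from the tests below, huggingface_hub >= 0.25.0 per the warning; the output directory is a placeholder):

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

quantized = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    quantization_config=TorchAoConfig("int8_weight_only"),
)
# The tests in this commit save without safetensors, which is why the
# weights_only handling added here matters when the checkpoint is reloaded.
quantized.save_pretrained("torchao-int8-checkpoint", safe_serialization=False)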
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import tempfile
 import unittest
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
@@ -236,5 +237,99 @@ class TorchAoTest(unittest.TestCase):
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
 
+@require_torch_gpu
+@require_torchao
+class TorchAoSerializationTest(unittest.TestCase):
+    input_text = "What are we having for dinner?"
+    max_new_tokens = 10
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
+    # TODO: investigate why we don't have the same output as the original model for this test
+    SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+    device = "cuda:0"
+
+    # called only once for all test in this class
+    @classmethod
+    def setUpClass(cls):
+        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
+            cls.model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=cls.device,
+            quantization_config=cls.quant_config,
+        )
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def test_original_model_expected_output(self):
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
+
+    def check_serialization_expected_output(self, device, expected_output):
+        """
+        Test if we can serialize and load/infer the model again on the same device
+        """
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
+            loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
+                self.model_name, torch_dtype=torch.bfloat16, device_map=self.device
+            )
+            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+            output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
+
+    def test_serialization_expected_output(self):
+        self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
+
+
+class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
+
+
+class TorchAoSerializationW8Test(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_weight_only")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
+
+
+class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cpu"
+
+    def test_serialization_expected_output_cuda(self):
+        """
+        Test if we can serialize on device (cpu) and load/infer the model on cuda
+        """
+        new_device = "cuda:0"
+        self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
+
+
+class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_weight_only")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cpu"
+
+    def test_serialization_expected_output_cuda(self):
+        """
+        Test if we can serialize on device (cpu) and load/infer the model on cuda
+        """
+        new_device = "cuda:0"
+        self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
+
+
 if __name__ == "__main__":
     unittest.main()
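The CPU/W8 variants above show the intended extension pattern: a subclass only overrides the config, the expected strings, and the device. A sketch of adding another variant; the class name, quant type choice, and expected output are placeholders to be regenerated on real hardware, and the snippet assumes this test module's imports:

class TorchAoSerializationNewVariantTest(TorchAoSerializationTest):
    quant_config = TorchAoConfig("int8_weight_only")  # swap in the config under test
    ORIGINAL_EXPECTED_OUTPUT = "..."                  # regenerate on the target hardware
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
    device = "cuda:0"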