Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Fix HQQ model param device transfer issue (#38466)

* Fix HQQ model param device transfer issue
* modify a comment
* clear the code and add test for hqq device/dtype
* fix test hqq code quality of imports

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent c77bcd889f
commit cb0f604192
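For context, here is a minimal usage sketch of what this change enables: moving an HQQ-quantized model across devices and dtypes through the regular `.to()` / `.cuda()` API. The model id and quantization settings below are illustrative assumptions, not values taken from this commit.

import torch
import accelerate
from transformers import AutoModelForCausalLM, HqqConfig

quant_config = HqqConfig(nbits=8, group_size=64)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder model id (an assumption)
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

# As in the new test below, accelerate hooks are removed first so the model can be moved freely.
accelerate.hooks.remove_hook_from_module(model, recurse=True)

# Before this commit, both calls raised ValueError for HQQ-quantized models;
# now they also move the tensors that HQQLinear keeps in its 'meta' attribute.
model.to("cpu", torch.bfloat16)
model.cuda(torch.device("cuda:0"))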
@@ -3897,7 +3897,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute,
+            # it's necessary to manually call the `cuda` method on HQQLinear layers.
+            super().cuda(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if len(args) > 0:
+                        device = args[0]
+                    else:
+                        device = kwargs.get("device", "cuda")
+                    module.cuda(device)
+            return self
+
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
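A quick sketch of the call forms the new `cuda` override resolves, assuming `model` is the HQQ-quantized model from the sketch above (this only restates the argument handling in the hunk, it is not an additional API):

model.cuda()                        # no arguments: device falls back to kwargs.get("device", "cuda")
model.cuda(torch.device("cuda:0"))  # positional device: device = args[0]
model.cuda(device=0)                # keyword device: picked up via kwargs.get("device", "cuda")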
@@ -3910,8 +3923,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
-        else:
-            return super().cuda(*args, **kwargs)
+        return super().cuda(*args, **kwargs)
 
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
@@ -3926,7 +3938,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                     break
 
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.to` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute, we must
+            # explicitly move the parameters to the target device for each HQQLinear layer after `to`.
+            super().to(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if "device" in kwargs:
+                        device = kwargs["device"]
+                    else:
+                        device = args[0]
+                    if "dtype" in kwargs:
+                        dtype = kwargs["dtype"]
+                    elif dtype_present_in_args:
+                        dtype = arg
+                    else:
+                        dtype = None
+                    # Due to the current messy implementation of HQQLinear, updating `compute_dtype`
+                    # followed by calling the `cuda` method achieves the intended behavior of `to`,
+                    # even when the target device is CPU.
+                    if dtype is not None:
+                        module.compute_dtype = dtype
+                    module.cuda(device)
+            return self
 
         if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
             raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
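One detail that is easy to miss because it sits outside this hunk: `dtype_present_in_args` and `arg` are set by the loop over positional arguments earlier in `to` (the `break` context line above belongs to that loop), so `dtype = arg` picks up a dtype that was passed positionally. A small sketch of the call forms this handles, assuming the HQQ-quantized `model` from the first sketch:

model.to("cpu", torch.bfloat16)               # positional: device = args[0], dtype found by the earlier arg scan
model.to(device="cpu", dtype=torch.bfloat16)  # keyword: device = kwargs["device"], dtype = kwargs["dtype"]
model.to("cpu")                               # device only: compute_dtype is left unchanged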
@@ -202,6 +202,15 @@ class HqqHfQuantizer(HfQuantizer):
         if is_hqq_available():
             from hqq.core.quantize import HQQLinear
 
+            # TODO: This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
+            # but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
+            # we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
+            @property
+            def weight(_self: HQQLinear):
+                return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)
+
+            HQQLinear.weight = weight
+
         module, tensor_name = get_module_from_name(model, param_name)
         layer_name = ".".join(param_name.split(".")[:-1])
         parent_module = find_parent(model, layer_name)
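The patch above relies on the fact that a `property` object can be attached to a class after its definition, at which point every instance exposes it as a read-only attribute. A minimal, self-contained illustration of the same pattern, using a hypothetical `FakeQuantLinear` class as a stand-in for HQQLinear:

import torch


class FakeQuantLinear:
    # Hypothetical stand-in for an HQQ-style layer that keeps no real `weight` tensor.
    def __init__(self, compute_dtype, device):
        self.compute_dtype = compute_dtype
        self.device = device


@property
def weight(_self):
    # Dummy tensor that carries only dtype/device metadata, no data.
    return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)


FakeQuantLinear.weight = weight

layer = FakeQuantLinear(torch.float16, "cpu")
print(layer.weight.dtype)   # torch.float16
print(layer.weight.numel()) # 0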
@@ -15,6 +15,8 @@
 import gc
 import unittest
 
+import accelerate
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
 from transformers.testing_utils import (
     backend_empty_cache,
@@ -119,6 +121,41 @@ class HQQTest(unittest.TestCase):
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)
 
+    def test_quantized_model_to_new_device_and_new_dtype(self):
+        """
+        Simple LLM model testing different devices and dtypes
+        """
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        original_device = hqq_runner.model.model.layers[0].self_attn.v_proj.device
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+        # Remove `accelerate` hooks to enable moving the model to a new device
+        accelerate.hooks.remove_hook_from_module(hqq_runner.model, recurse=True)
+
+        hqq_runner.model.to("cpu", torch.bfloat16)
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+        hqq_runner.model.cuda(original_device)
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+    def test_quantized_model_fake_weight_dtype(self):
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        # We use a hack to inject a fake weight into HQQLinear. Check that it works
+        self.assertEqual(hqq_runner.model.model.layers[0].self_attn.v_proj.weight.dtype, torch.float16)
+
 
 @slow
 @require_torch_gpu