Fix HQQ model param device transfer issue (#38466)

* Fix HQQ model param device transfer issue

* modify a comment

* clear the code and add test for hqq device/dtype

* fix test hqq code quality of imports

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
艾梦 2025-06-18 21:09:00 +08:00 committed by GitHub
parent c77bcd889f
commit cb0f604192
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 4 deletions

View File

@ -3897,7 +3897,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
from hqq.core.quantize import HQQLinear
# Since HQQLinear stores some tensors in the 'meta' attribute,
# it's necessary to manually call the `cuda` method on HQQLinear layers.
super().cuda(*args, **kwargs)
for module in self.modules():
if isinstance(module, HQQLinear):
if len(args) > 0:
device = args[0]
else:
device = kwargs.get("device", "cuda")
module.cuda(device)
return self
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False):
@ -3910,8 +3923,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
else:
return super().cuda(*args, **kwargs)
return super().cuda(*args, **kwargs)
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
@ -3926,7 +3938,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
break
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
from hqq.core.quantize import HQQLinear
# Since HQQLinear stores some tensors in the 'meta' attribute, we must
# explicitly move the parameters to the target device for each HQQLinear layer after `to`.
super().to(*args, **kwargs)
for module in self.modules():
if isinstance(module, HQQLinear):
if "device" in kwargs:
device = kwargs["device"]
else:
device = args[0]
if "dtype" in kwargs:
dtype = kwargs["dtype"]
elif dtype_present_in_args:
dtype = arg
else:
dtype = None
# Due to the current messy implementation of HQQLinear, updating `compute_dtype`
# followed by calling the `cuda` method achieves the intended behavior of `to`,
# even when the target device is CPU.
if dtype is not None:
module.compute_dtype = dtype
module.cuda(device)
return self
if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")

View File

@ -202,6 +202,15 @@ class HqqHfQuantizer(HfQuantizer):
if is_hqq_available():
from hqq.core.quantize import HQQLinear
# TODO: This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
# but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
# we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
@property
def weight(_self: HQQLinear):
return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)
HQQLinear.weight = weight
module, tensor_name = get_module_from_name(model, param_name)
layer_name = ".".join(param_name.split(".")[:-1])
parent_module = find_parent(model, layer_name)

View File

@ -15,6 +15,8 @@
import gc
import unittest
import accelerate
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
from transformers.testing_utils import (
backend_empty_cache,
@ -119,6 +121,41 @@ class HQQTest(unittest.TestCase):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_quantized_model_to_new_device_and_new_dtype(self):
"""
Simple LLM model testing different devices and dtypes
"""
quant_config = HqqConfig(nbits=8, group_size=64)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
original_device = hqq_runner.model.model.layers[0].self_attn.v_proj.device
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
# Remove `accelerate` hooks to enable move the model to a new device
accelerate.hooks.remove_hook_from_module(hqq_runner.model, recurse=True)
hqq_runner.model.to("cpu", torch.bfloat16)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
hqq_runner.model.cuda(original_device)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_quantized_model_fake_weight_dtype(self):
quant_config = HqqConfig(nbits=8, group_size=64)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
# We use a hack to inject a fake weight to HQQLinear. Check that it works
self.assertEqual(hqq_runner.model.model.layers[0].self_attn.v_proj.weight.dtype, torch.float16)
@slow
@require_torch_gpu