From cb0f60419231ebec83b9ced356bb47a74a4009ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=89=BE=E6=A2=A6?=
Date: Wed, 18 Jun 2025 21:09:00 +0800
Subject: [PATCH] Fix HQQ model param device transfer issue (#38466)

* Fix HQQ model param device transfer issue

* modify a comment

* clear the code and add test for hqq device/dtype

* fix test hqq code quality of imports

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
---
 src/transformers/modeling_utils.py           | 43 ++++++++++++++++++--
 src/transformers/quantizers/quantizer_hqq.py |  9 ++++
 tests/quantization/hqq/test_hqq.py           | 37 +++++++++++++++++
 3 files changed, 85 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ae6f194a90b..3b235d9aeac 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3897,7 +3897,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute,
+            # it's necessary to manually call the `cuda` method on HQQLinear layers.
+            super().cuda(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if len(args) > 0:
+                        device = args[0]
+                    else:
+                        device = kwargs.get("device", "cuda")
+                    module.cuda(device)
+            return self
+
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
@@ -3910,8 +3923,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
-        else:
-            return super().cuda(*args, **kwargs)
+        return super().cuda(*args, **kwargs)

     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
@@ -3926,7 +3938,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 break

         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.to` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute, we must
+            # explicitly move the parameters to the target device for each HQQLinear layer after `to`.
+            super().to(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if "device" in kwargs:
+                        device = kwargs["device"]
+                    else:
+                        device = args[0]
+                    if "dtype" in kwargs:
+                        dtype = kwargs["dtype"]
+                    elif dtype_present_in_args:
+                        dtype = arg
+                    else:
+                        dtype = None
+                    # Due to the current messy implementation of HQQLinear, updating `compute_dtype`
+                    # followed by calling the `cuda` method achieves the intended behavior of `to`,
+                    # even when the target device is CPU.
+                    if dtype is not None:
+                        module.compute_dtype = dtype
+                    module.cuda(device)
+            return self

         if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
             raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py
index a0cc0170e18..6061c72c249 100755
--- a/src/transformers/quantizers/quantizer_hqq.py
+++ b/src/transformers/quantizers/quantizer_hqq.py
@@ -202,6 +202,15 @@ class HqqHfQuantizer(HfQuantizer):
         if is_hqq_available():
             from hqq.core.quantize import HQQLinear

+            # TODO: This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
+            # but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
+            # we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
+            @property
+            def weight(_self: HQQLinear):
+                return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)
+
+            HQQLinear.weight = weight
+
         module, tensor_name = get_module_from_name(model, param_name)
         layer_name = ".".join(param_name.split(".")[:-1])
         parent_module = find_parent(model, layer_name)
diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py
index 877f6a2cd8d..37d91e9a259 100755
--- a/tests/quantization/hqq/test_hqq.py
+++ b/tests/quantization/hqq/test_hqq.py
@@ -15,6 +15,8 @@
 import gc
 import unittest

+import accelerate
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
 from transformers.testing_utils import (
     backend_empty_cache,
@@ -119,6 +121,41 @@ class HQQTest(unittest.TestCase):
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)

+    def test_quantized_model_to_new_device_and_new_dtype(self):
+        """
+        Simple LLM test that moves the quantized model across devices and dtypes
+        """
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        original_device = hqq_runner.model.model.layers[0].self_attn.v_proj.device
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+        # Remove `accelerate` hooks to enable moving the model to a new device
+        accelerate.hooks.remove_hook_from_module(hqq_runner.model, recurse=True)
+
+        hqq_runner.model.to("cpu", torch.bfloat16)
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+        hqq_runner.model.cuda(original_device)
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+    def test_quantized_model_fake_weight_dtype(self):
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        # We use a hack to inject a fake `weight` into HQQLinear. Check that it works.
+        self.assertEqual(hqq_runner.model.model.layers[0].self_attn.v_proj.weight.dtype, torch.float16)
+

 @slow
 @require_torch_gpu
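
For reference, below is a minimal usage sketch of what this patch enables, not taken from the
diff above: before this change, calling `.to()` or `.cuda()` on an HQQ-quantized model raised a
`ValueError`. The checkpoint name is a placeholder, and `hqq` plus `accelerate` are assumed to
be installed.

import accelerate
import torch

from transformers import AutoModelForCausalLM, HqqConfig

# Load a causal LM quantized with HQQ (8-bit, group size 64, matching the quant config in the new tests).
quant_config = HqqConfig(nbits=8, group_size=64)
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-causal-lm",  # placeholder model id
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

# As in the new test, remove the accelerate hooks added by `device_map` loading
# before moving the model to a different device.
accelerate.hooks.remove_hook_from_module(model, recurse=True)

# With this patch, `.to()` updates `compute_dtype` and moves the HQQLinear layers,
# and `.cuda()` moves them back to the GPU.
model.to("cpu", torch.bfloat16)
model.cuda(torch.device("cuda:0"))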