From 7752e7487c3ae675ea79f54cbb056f8e90446d59 Mon Sep 17 00:00:00 2001
From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Date: Wed, 16 Apr 2025 13:58:14 +0200
Subject: [PATCH] Fixes hqq by following a new path for bias parameter in
 pre_quantized models (#37530)

* fix

* add test
---
 src/transformers/quantizers/quantizer_hqq.py |  9 ++----
 tests/quantization/hqq/test_hqq.py           | 33 ++++++++++++++++++++
 2 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py
index 4adc323f958..38d8a15cbfc 100755
--- a/src/transformers/quantizers/quantizer_hqq.py
+++ b/src/transformers/quantizers/quantizer_hqq.py
@@ -170,16 +170,13 @@ class HqqHfQuantizer(HfQuantizer):
         module, tensor_name = get_module_from_name(model, param_name)
 
         if self.pre_quantized:
-            return (
-                (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
-                and tensor_name != "weight"
-                and tensor_name != "bias"
-            )
+            return (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear)) and tensor_name != "weight"
         else:
-            # we need a special path for bias since hqq overwrote load_state_dict for this layer
             return (
                 isinstance(module, torch.nn.Linear)
                 and tensor_name == "weight"
+                # bias doesn't need to be quantized, we use this as a workaround to avoid loading bias into HQQLinear assuming it was loaded
+                # in the state_dict directly with the weight because hqq overwrote load_state_dict for this layer
                 or (isinstance(module, HQQLinear) and tensor_name == "bias")
             )
 
diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py
index 221248d7884..a686bbd7de7 100755
--- a/tests/quantization/hqq/test_hqq.py
+++ b/tests/quantization/hqq/test_hqq.py
@@ -165,6 +165,39 @@ class HQQTestBias(unittest.TestCase):
         check_hqqlayer(self, hqq_runner.model.model.decoder.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)
 
+    def test_save_and_load_quantized_model(self):
+        """
+        Test saving and loading a quantized model with bias
+        """
+        import tempfile
+
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id="facebook/opt-125m", quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
+
+        # Get reference logits
+        with torch.no_grad():
+            logits_ref = hqq_runner.model.forward(input_tensor).logits
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            hqq_runner.model.save_pretrained(tmpdirname)
+
+            del hqq_runner.model
+            torch.cuda.empty_cache()
+
+            model_loaded = AutoModelForCausalLM.from_pretrained(
+                tmpdirname, torch_dtype=torch.float16, device_map=torch_device
+            )
+
+        with torch.no_grad():
+            logits_loaded = model_loaded.forward(input_tensor).logits
+
+        self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
+
 
 @slow
 @require_torch_gpu
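
Editorial note, not part of the patch: below is a minimal, self-contained sketch of the parameter-routing predicate changed in quantizer_hqq.py above, not the library implementation. The HQQLinear class is stubbed with a placeholder since the hqq package may not be installed (in transformers it comes from hqq), the helper name needs_quantizer_handling is hypothetical, and the real method also resolves the module via get_module_from_name. It only illustrates that, for pre-quantized checkpoints, the bias now matches the predicate instead of being filtered out.

import torch


class HQQLinear(torch.nn.Module):
    # Placeholder stand-in for hqq's HQQLinear (assumption: hqq is not installed here).
    pass


def needs_quantizer_handling(module: torch.nn.Module, tensor_name: str, pre_quantized: bool) -> bool:
    # Mirrors the predicate from HqqHfQuantizer after the fix (hypothetical standalone helper).
    if pre_quantized:
        # Only the already-quantized weight is excluded; the bias now matches too, so it is
        # routed through the quantizer's parameter handling. The old code also excluded
        # "bias", which broke loading of pre-quantized models with bias.
        return isinstance(module, (torch.nn.Linear, HQQLinear)) and tensor_name != "weight"
    # Quantization path: take the float weight of a plain Linear, and route the bias of an
    # already-converted HQQLinear here because hqq overrides load_state_dict for that layer.
    return (isinstance(module, torch.nn.Linear) and tensor_name == "weight") or (
        isinstance(module, HQQLinear) and tensor_name == "bias"
    )


if __name__ == "__main__":
    linear = torch.nn.Linear(4, 4)
    print(needs_quantizer_handling(linear, "bias", pre_quantized=True))    # True after the fix
    print(needs_quantizer_handling(linear, "weight", pre_quantized=True))  # False: weight is already quantized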
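
A hedged usage sketch of the end-to-end flow the new test exercises: quantize a model that has biases with HQQ, save it, and reload the pre-quantized checkpoint. The model id and HqqConfig settings are taken from the test; the output directory is illustrative and a CUDA device is assumed.

import torch

from transformers import AutoModelForCausalLM, HqqConfig

# Quantize on the fly while loading (same settings as the new test).
quant_config = HqqConfig(nbits=8, group_size=64)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

# Save the quantized model; reloading it goes through the pre_quantized branch fixed above,
# so bias tensors are now picked up correctly.
model.save_pretrained("./opt-125m-hqq")  # illustrative path
reloaded = AutoModelForCausalLM.from_pretrained(
    "./opt-125m-hqq", torch_dtype=torch.float16, device_map="cuda"
)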