Fixes hqq by following a new path for the bias parameter in pre_quantized models (#37530)

* fix

* add test
Mohamed Mekkouri 2025-04-16 13:58:14 +02:00 committed by GitHub
parent 7dafcd0077
commit 7752e7487c
2 changed files with 36 additions and 6 deletions
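
For context, a minimal repro sketch of the failure mode this commit addresses (assumptions: a CUDA device, the hqq package installed, and an illustrative output directory name; the HqqConfig values mirror the new test below). Quantizing on the fly worked; reloading the saved checkpoint takes the pre_quantized path patched below, where bias tensors could be mishandled before this fix:

    import torch
    from transformers import AutoModelForCausalLM, HqqConfig

    # Quantize a small model whose Linear layers carry biases.
    quant_config = HqqConfig(nbits=8, group_size=64)
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",
        torch_dtype=torch.float16,
        device_map="cuda",
        quantization_config=quant_config,
    )
    model.save_pretrained("opt-125m-hqq")  # illustrative path

    # Reloading goes through the pre_quantized code path; before this fix,
    # bias parameters could be skipped there and outputs diverged.
    reloaded = AutoModelForCausalLM.from_pretrained(
        "opt-125m-hqq",
        torch_dtype=torch.float16,
        device_map="cuda",
    )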


@@ -170,16 +170,13 @@ class HqqHfQuantizer(HfQuantizer):
         module, tensor_name = get_module_from_name(model, param_name)
 
         if self.pre_quantized:
-            return (
-                (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
-                and tensor_name != "weight"
-                and tensor_name != "bias"
-            )
+            return (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear)) and tensor_name != "weight"
         else:
-            # we need a special path for bias since hqq overwrote load_state_dict for this layer
             return (
                 isinstance(module, torch.nn.Linear)
                 and tensor_name == "weight"
+                # bias doesn't need to be quantized, we use this as a workaround to avoid loading bias into HQQLinear assuming it was loaded
+                # in the state_dict directly with the weight because hqq overwrote load_state_dict for this layer
                 or (isinstance(module, HQQLinear) and tensor_name == "bias")
             )
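
To see the behavioral change in isolation, here is a standalone sketch of the predicate above; FakeHQQLinear and should_handle are hypothetical stand-ins for hqq's HQQLinear and the quantizer's check_quantized_param:

    import torch

    class FakeHQQLinear(torch.nn.Module):
        # Stand-in for hqq.core.quantize.HQQLinear.
        pass

    def should_handle(module, tensor_name, pre_quantized):
        if pre_quantized:
            # New behavior: only "weight" is excluded, so "bias" now flows
            # through the quantizer for pre-quantized checkpoints.
            return isinstance(module, (torch.nn.Linear, FakeHQQLinear)) and tensor_name != "weight"
        return (isinstance(module, torch.nn.Linear) and tensor_name == "weight") or (
            isinstance(module, FakeHQQLinear) and tensor_name == "bias"
        )

    # Before the fix, the pre-quantized branch also required tensor_name != "bias",
    # so the first call returned False and the bias bypassed the quantizer:
    print(should_handle(FakeHQQLinear(), "bias", pre_quantized=True))          # True
    print(should_handle(torch.nn.Linear(2, 2), "weight", pre_quantized=True))  # False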


@@ -165,6 +165,39 @@ class HQQTestBias(unittest.TestCase):
         check_hqqlayer(self, hqq_runner.model.model.decoder.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)
 
+    def test_save_and_load_quantized_model(self):
+        """
+        Test saving and loading a quantized model with bias
+        """
+        import tempfile
+
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id="facebook/opt-125m", quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
+
+        # Get reference logits
+        with torch.no_grad():
+            logits_ref = hqq_runner.model.forward(input_tensor).logits
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            hqq_runner.model.save_pretrained(tmpdirname)
+
+            del hqq_runner.model
+            torch.cuda.empty_cache()
+
+            model_loaded = AutoModelForCausalLM.from_pretrained(
+                tmpdirname, torch_dtype=torch.float16, device_map=torch_device
+            )
+
+        with torch.no_grad():
+            logits_loaded = model_loaded.forward(input_tensor).logits
+
+        self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
+
 
 @slow
 @require_torch_gpu
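
A brief note on the assertion at the end of the new test: exact equality is intentional, since the same quantized tensors are serialized and read back, dequantization should be deterministic and the logits bit-identical. If that ever proves flaky, a tolerance-based variant via the standard torch.testing API would be the natural relaxation (reusing the test's logits_ref and logits_loaded; rtol=atol=0 reproduces the exact check):

    torch.testing.assert_close(logits_loaded, logits_ref, rtol=0, atol=0)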