Fix hqq skipped modules and dynamic quant (#36821)

* Fix hqq skip_modules and dynamic_quant

* Fix loading of skipped modules

* Add a test for dynamic quant and skip_modules in HqqConfig
mobicham 2025-03-20 15:31:49 +01:00 committed by GitHub
parent 055afdb6bb
commit 3e8f0fbf44
2 changed files with 57 additions and 3 deletions
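
For context, the behavior touched here is configured through HqqConfig's dynamic_config and skip_modules fields. Below is a minimal usage sketch mirroring the new test further down; the checkpoint id, device settings, and exact bit widths are illustrative and not part of this commit.

import torch
from transformers import AutoModelForCausalLM, HqqConfig

# Per-module settings are keyed by the module "tag" (the last two components of the
# module name); anything matching an entry in skip_modules is left unquantized.
quant_config = HqqConfig(
    dynamic_config={
        "self_attn.q_proj": {"nbits": 4, "group_size": 64},
        "mlp.gate_proj": {"nbits": 3, "group_size": 64},
    },
    skip_modules=["lm_head", "down_proj"],
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint id
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)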


@@ -124,7 +124,14 @@ class HqqHfQuantizer(HfQuantizer):
         # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
         _valid_modules = set()
         _find_hqq_quantizable_layers(model, _valid_modules)
-        _valid_modules -= set(model.config.quantization_config["skip_modules"])
+
+        # Remove skipped modules
+        _skipped_modules = set()
+        for _module in _valid_modules:
+            for _skip_module in model.config.quantization_config["skip_modules"]:
+                if _skip_module in _module:
+                    _skipped_modules.add(_module)
+        _valid_modules -= _skipped_modules

         # Append new expected layers based on _ref_keys
         _ref_keys = HQQLinear(
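
The removed line only subtracted exact name matches, so suffix-style entries such as "down_proj" never matched fully qualified names like "model.layers.0.mlp.down_proj". A standalone sketch of the substring matching introduced above (the module names are made up for illustration):

# Hypothetical module names; skip entries match anywhere in the qualified name.
valid_modules = {
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.down_proj",
    "lm_head",
}
skip_modules = ["lm_head", "down_proj"]

skipped = {m for m in valid_modules for s in skip_modules if s in m}
valid_modules -= skipped

print(sorted(valid_modules))  # ['model.layers.0.self_attn.q_proj']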
@@ -243,10 +250,24 @@ class HqqHfQuantizer(HfQuantizer):
         # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
         # directly doesn't work.

         if hasattr(module, "quant_config"):
+            quant_config = model.config.quantization_config["quant_config"]
+            skip_modules = model.config.quantization_config["skip_modules"]
+            module_tag = ".".join(module.name.split(".")[-2:])
+            module_quant_config = None
+            if "weight_quant_params" in quant_config:
+                module_quant_config = quant_config
+            elif module_tag in quant_config:
+                module_quant_config = quant_config[module_tag]
+
+            for skip_module in skip_modules:
+                if skip_module in module.name:
+                    module_quant_config = None
+                    break
+
+        if module_quant_config is not None:
             hqq_layer = HQQLinear(
                 module,
-                module.quant_config,
+                quant_config=module_quant_config,
                 compute_dtype=self.torch_dtype,
                 device=target_device,
                 del_orig=True,
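
A small sketch of how the block above resolves a per-module config. The dict layout (per-tag entries holding "weight_quant_params") is inferred from the checks above and the test's assertions below; the module name and values are illustrative.

# A dynamic quant_config maps module tags to HQQ quant params; a flat (non-dynamic)
# config would instead carry "weight_quant_params" at the top level.
quant_config = {
    "self_attn.q_proj": {"weight_quant_params": {"nbits": 4, "group_size": 64}},
    "mlp.gate_proj": {"weight_quant_params": {"nbits": 3, "group_size": 64}},
}
skip_modules = ["lm_head", "down_proj"]

module_name = "model.layers.7.self_attn.q_proj"     # hypothetical full module name
module_tag = ".".join(module_name.split(".")[-2:])  # -> "self_attn.q_proj"

module_quant_config = None
if "weight_quant_params" in quant_config:    # flat config: applies to every quantizable layer
    module_quant_config = quant_config
elif module_tag in quant_config:             # dynamic config: look up this module's tag
    module_quant_config = quant_config[module_tag]

if any(s in module_name for s in skip_modules):  # skipped modules stay unquantized
    module_quant_config = None

print(module_quant_config)  # {'weight_quant_params': {'nbits': 4, 'group_size': 64}}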


@@ -207,3 +207,36 @@ class HQQSerializationTest(unittest.TestCase):
         logits_loaded = model_loaded.forward(input_tensor).logits

         self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
+
+    def test_model_serialization_dynamic_quant_with_skip(self):
+        """
+        Simple HQQ LLM save/load test with dynamic quant
+        """
+        q4_config = {"nbits": 4, "group_size": 64}
+        q3_config = {"nbits": 3, "group_size": 64}
+
+        quant_config = HqqConfig(
+            dynamic_config={
+                "self_attn.q_proj": q4_config,
+                "self_attn.k_proj": q4_config,
+                "self_attn.v_proj": q4_config,
+                "self_attn.o_proj": q4_config,
+                "mlp.gate_proj": q3_config,
+                "mlp.up_proj": q3_config,
+            },
+            skip_modules=["lm_head", "down_proj"],
+        )
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        model = hqq_runner.model
+
+        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
+        with torch.no_grad():
+            model.forward(input_tensor).logits
+
+        self.assertEqual(isinstance(model.model.layers[1].mlp.down_proj, torch.nn.Linear), True)
+        self.assertEqual(model.model.layers[1].self_attn.v_proj.quant_config["weight_quant_params"]["nbits"], 4)
+        self.assertEqual(model.model.layers[1].mlp.gate_proj.quant_config["weight_quant_params"]["nbits"], 3)
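
Taken together, the new assertions cover both halves of the fix: down_proj is matched by substring in skip_modules and therefore stays a plain torch.nn.Linear, while the attention projections and mlp.gate_proj pick up their 4-bit and 3-bit settings from dynamic_config.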