mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Fix hqq skipped modules and dynamic quant (#36821)
* Fix hqq skip_modules and dynamic_quant * fix skipped modules loading * add dynamic/skip HqqConfig test
This commit is contained in:
parent
055afdb6bb
commit
3e8f0fbf44
@ -124,7 +124,14 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
# valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
|
||||
_valid_modules = set()
|
||||
_find_hqq_quantizable_layers(model, _valid_modules)
|
||||
_valid_modules -= set(model.config.quantization_config["skip_modules"])
|
||||
|
||||
# Remove skipped modules
|
||||
_skipped_modules = set()
|
||||
for _module in _valid_modules:
|
||||
for _skip_module in model.config.quantization_config["skip_modules"]:
|
||||
if _skip_module in _module:
|
||||
_skipped_modules.add(_module)
|
||||
_valid_modules -= _skipped_modules
|
||||
|
||||
# Append new expected layers based on _ref_keys
|
||||
_ref_keys = HQQLinear(
|
||||
@ -243,10 +250,24 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
|
||||
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
|
||||
# directly doesn't work.
|
||||
if hasattr(module, "quant_config"):
|
||||
quant_config = model.config.quantization_config["quant_config"]
|
||||
skip_modules = model.config.quantization_config["skip_modules"]
|
||||
module_tag = ".".join(module.name.split(".")[-2:])
|
||||
module_quant_config = None
|
||||
if "weight_quant_params" in quant_config:
|
||||
module_quant_config = quant_config
|
||||
elif module_tag in quant_config:
|
||||
module_quant_config = quant_config[module_tag]
|
||||
|
||||
for skip_module in skip_modules:
|
||||
if skip_module in module.name:
|
||||
module_quant_config = None
|
||||
break
|
||||
|
||||
if module_quant_config is not None:
|
||||
hqq_layer = HQQLinear(
|
||||
module,
|
||||
module.quant_config,
|
||||
quant_config=module_quant_config,
|
||||
compute_dtype=self.torch_dtype,
|
||||
device=target_device,
|
||||
del_orig=True,
|
||||
|
@ -207,3 +207,36 @@ class HQQSerializationTest(unittest.TestCase):
|
||||
logits_loaded = model_loaded.forward(input_tensor).logits
|
||||
|
||||
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
|
||||
|
||||
def test_model_serialization_dynamic_quant_with_skip(self):
|
||||
"""
|
||||
Simple HQQ LLM save/load test with dynamic quant
|
||||
"""
|
||||
q4_config = {"nbits": 4, "group_size": 64}
|
||||
q3_config = {"nbits": 3, "group_size": 64}
|
||||
|
||||
quant_config = HqqConfig(
|
||||
dynamic_config={
|
||||
"self_attn.q_proj": q4_config,
|
||||
"self_attn.k_proj": q4_config,
|
||||
"self_attn.v_proj": q4_config,
|
||||
"self_attn.o_proj": q4_config,
|
||||
"mlp.gate_proj": q3_config,
|
||||
"mlp.up_proj": q3_config,
|
||||
},
|
||||
skip_modules=["lm_head", "down_proj"],
|
||||
)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
)
|
||||
|
||||
model = hqq_runner.model
|
||||
|
||||
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
|
||||
with torch.no_grad():
|
||||
model.forward(input_tensor).logits
|
||||
|
||||
self.assertEqual(isinstance(model.model.layers[1].mlp.down_proj, torch.nn.Linear), True)
|
||||
self.assertEqual(model.model.layers[1].self_attn.v_proj.quant_config["weight_quant_params"]["nbits"], 4)
|
||||
self.assertEqual(model.model.layers[1].mlp.gate_proj.quant_config["weight_quant_params"]["nbits"], 3)
|
||||
|
Loading…
Reference in New Issue
Block a user