Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-17 19:48:23 +06:00)
Fix hqq skipped modules and dynamic quant (#36821)
* Fix hqq skip_modules and dynamic_quant
* fix skipped modules loading
* add dynamic/skip HqqConfig test
This commit is contained in:
parent 055afdb6bb
commit 3e8f0fbf44
@@ -124,7 +124,14 @@ class HqqHfQuantizer(HfQuantizer):
         # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
         _valid_modules = set()
         _find_hqq_quantizable_layers(model, _valid_modules)
-        _valid_modules -= set(model.config.quantization_config["skip_modules"])
+
+        # Remove skipped modules
+        _skipped_modules = set()
+        for _module in _valid_modules:
+            for _skip_module in model.config.quantization_config["skip_modules"]:
+                if _skip_module in _module:
+                    _skipped_modules.add(_module)
+        _valid_modules -= _skipped_modules
 
         # Append new expected layers based on _ref_keys
         _ref_keys = HQQLinear(
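
Note on the change above: the skip list is now applied as a substring match against each module's qualified name rather than an exact set difference, so a single entry like "down_proj" removes every layer whose name contains it. A minimal standalone sketch of that filtering (the module names and skip list are illustrative only, not taken from a real model):

# Illustrative module names; any qualified name containing a skip entry is dropped.
valid_modules = {
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.down_proj",
    "model.layers.1.mlp.down_proj",
    "lm_head",
}
skip_modules = ["lm_head", "down_proj"]

# Substring-based filtering, equivalent to the nested loops in the hunk above.
skipped = {m for m in valid_modules for s in skip_modules if s in m}
valid_modules -= skipped

print(sorted(valid_modules))  # only the self_attn.q_proj entry remains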
@@ -243,10 +250,24 @@ class HqqHfQuantizer(HfQuantizer):
 
         # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
         # directly doesn't work.
-        if hasattr(module, "quant_config"):
+        quant_config = model.config.quantization_config["quant_config"]
+        skip_modules = model.config.quantization_config["skip_modules"]
+        module_tag = ".".join(module.name.split(".")[-2:])
+        module_quant_config = None
+        if "weight_quant_params" in quant_config:
+            module_quant_config = quant_config
+        elif module_tag in quant_config:
+            module_quant_config = quant_config[module_tag]
+
+        for skip_module in skip_modules:
+            if skip_module in module.name:
+                module_quant_config = None
+                break
+
+        if module_quant_config is not None:
             hqq_layer = HQQLinear(
                 module,
-                module.quant_config,
+                quant_config=module_quant_config,
                 compute_dtype=self.torch_dtype,
                 device=target_device,
                 del_orig=True,
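
The replacement step now resolves a per-module config: a flat quant_config (one that carries "weight_quant_params") applies to every layer, a dynamic config is looked up by the last two components of the module name, and any skip_modules substring match overrides the result to None so the layer stays unquantized. A hedged standalone sketch of that resolution, using plain dicts in place of the real transformers objects:

# Standalone sketch of the per-module config resolution in the hunk above.
# quant_config, skip_modules and module_name are illustrative values.
def resolve_module_quant_config(module_name, quant_config, skip_modules):
    module_tag = ".".join(module_name.split(".")[-2:])  # e.g. "mlp.gate_proj"

    module_quant_config = None
    if "weight_quant_params" in quant_config:   # flat config: applies globally
        module_quant_config = quant_config
    elif module_tag in quant_config:            # dynamic config: per-layer-tag lookup
        module_quant_config = quant_config[module_tag]

    # skip_modules always wins: a substring match disables quantization
    for skip_module in skip_modules:
        if skip_module in module_name:
            return None
    return module_quant_config


dynamic = {"self_attn.q_proj": {"nbits": 4}, "mlp.gate_proj": {"nbits": 3}}
print(resolve_module_quant_config("model.layers.0.mlp.gate_proj", dynamic, ["lm_head"]))   # {'nbits': 3}
print(resolve_module_quant_config("model.layers.0.mlp.down_proj", dynamic, ["down_proj"]))  # None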
@@ -207,3 +207,36 @@ class HQQSerializationTest(unittest.TestCase):
         logits_loaded = model_loaded.forward(input_tensor).logits
 
         self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
+
+    def test_model_serialization_dynamic_quant_with_skip(self):
+        """
+        Simple HQQ LLM save/load test with dynamic quant
+        """
+        q4_config = {"nbits": 4, "group_size": 64}
+        q3_config = {"nbits": 3, "group_size": 64}
+
+        quant_config = HqqConfig(
+            dynamic_config={
+                "self_attn.q_proj": q4_config,
+                "self_attn.k_proj": q4_config,
+                "self_attn.v_proj": q4_config,
+                "self_attn.o_proj": q4_config,
+                "mlp.gate_proj": q3_config,
+                "mlp.up_proj": q3_config,
+            },
+            skip_modules=["lm_head", "down_proj"],
+        )
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        model = hqq_runner.model
+
+        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
+        with torch.no_grad():
+            model.forward(input_tensor).logits
+
+        self.assertEqual(isinstance(model.model.layers[1].mlp.down_proj, torch.nn.Linear), True)
+        self.assertEqual(model.model.layers[1].self_attn.v_proj.quant_config["weight_quant_params"]["nbits"], 4)
+        self.assertEqual(model.model.layers[1].mlp.gate_proj.quant_config["weight_quant_params"]["nbits"], 3)
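
Outside the test harness, the same dynamic/skip configuration can be passed to from_pretrained via quantization_config. A sketch under the assumption of a Llama-style checkpoint whose layer names match the tags below; the model id is a placeholder and the hqq package must be installed:

import torch
from transformers import AutoModelForCausalLM, HqqConfig

q4 = {"nbits": 4, "group_size": 64}
q3 = {"nbits": 3, "group_size": 64}

# Per-layer-tag configs plus substring-based skips, as exercised by the new test above.
quant_config = HqqConfig(
    dynamic_config={
        "self_attn.q_proj": q4,
        "self_attn.k_proj": q4,
        "self_attn.v_proj": q4,
        "self_attn.o_proj": q4,
        "mlp.gate_proj": q3,
        "mlp.up_proj": q3,
    },
    skip_modules=["lm_head", "down_proj"],  # these stay as plain nn.Linear
)

# "meta-llama/Llama-3.2-1B" is a placeholder; any causal LM with matching layer names works.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)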