mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Merge 8d9de7af15
into 37a239ca50
This commit is contained in:
commit
1925806e0a
@ -124,6 +124,21 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
|
||||
elif not self.quantization_config.is_quantization_compressed:
|
||||
apply_quantization_config(model, ct_quantization_config)
|
||||
|
||||
# Identify quantized modules with integer weights/activations
|
||||
quant_targets = set()
|
||||
if ct_quantization_config:
|
||||
for group in ct_quantization_config.config_groups.values():
|
||||
if group.weights.type.startswith("int") or (
|
||||
group.input_activations and group.input_activations.type.startswith("int")
|
||||
):
|
||||
quant_targets.update(group.targets)
|
||||
|
||||
# Disable gradient computation for quantized int modules
|
||||
for module in model.modules():
|
||||
if type(module).__name__ in quant_targets:
|
||||
for param in module.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _process_model_after_weight_loading(self, model, **kwargs):
|
||||
"""Decompress loaded model if necessary - need for qat"""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user