Allow compressed-tensors quantized model to be trained (#34520)

* populate quantization_config for kv-cache-scheme only configs

* make compressed-tensors quantized models trainable

* populate versions on quant config

* pass oneshot then finetune

* remove breakpoint

* SunMarc comments and fix to_dict logic

* lint

* lint

* test

* comment

* comments
George 2024-11-28 09:05:16 -05:00 committed by GitHub
parent 44af935ec5
commit 57ca9e6d2f
4 changed files with 36 additions and 12 deletions

src/transformers/quantizers/base.py

@@ -226,6 +226,11 @@ class HfQuantizer(ABC):
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
         )
 
+    @property
+    def is_qat_trainable(self) -> bool:
+        """Flag indicating whether the quantized model can carry out quantization aware training"""
+        return False
+
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...
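The base class keeps a conservative default of False; a quantizer that supports quantization aware training opts in by overriding the new property. A minimal sketch of such a subclass (MyQATQuantizer is hypothetical, not part of this commit):

from transformers.quantizers import HfQuantizer


class MyQATQuantizer(HfQuantizer):
    """Hypothetical quantizer whose quantized weights stay trainable."""

    requires_calibration = False

    def _process_model_before_weight_loading(self, model, **kwargs):
        pass  # e.g. replace nn.Linear modules with quantized equivalents

    def _process_model_after_weight_loading(self, model, **kwargs):
        pass

    def is_serializable(self, safe_serialization=None) -> bool:
        return False

    @property
    def is_trainable(self) -> bool:
        return True

    @property
    def is_qat_trainable(self) -> bool:
        # opting in lets Trainer fine-tune the model without PEFT adapters
        return True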

src/transformers/quantizers/quantizer_compressed_tensors.py

@@ -65,12 +65,19 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
         ct_quantization_config = self.compressor.quantization_config
         apply_quantization_config(model, ct_quantization_config, run_compressed=True)
 
-    def _process_model_after_weight_loading(self, model, **kwargs):
+    def _process_model_after_weight_loading(self, model, **kwargs) -> None:
         pass
 
     @property
-    def is_trainable(self):
-        return False
+    def is_trainable(self) -> bool:
+        """Models quantized using compressed tensors can be finetuned"""
+        return True
+
+    @property
+    def is_qat_trainable(self) -> bool:
+        """Loaded Models can carry out quantization aware training"""
+        return True
 
-    def is_serializable(self, safe_serialization=None):
-        return False
+    def is_serializable(self, safe_serialization=None) -> bool:
+        """Models quantized using compressed tensors can be saved to disk"""
+        return True
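With is_trainable and is_qat_trainable both returning True, a compressed-tensors checkpoint can be handed straight to Trainer for full fine-tuning instead of being rejected. A hedged sketch (the model id and train_dataset are placeholders, not from this commit):

from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# placeholder id: any checkpoint whose config.json carries a
# compressed-tensors quantization_config
model = AutoModelForCausalLM.from_pretrained("org/model-compressed-tensors")

assert model.hf_quantizer.is_trainable      # True after this commit
assert model.hf_quantizer.is_qat_trainable  # True after this commit

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="finetuned", max_steps=10),
    train_dataset=train_dataset,  # assumed: an already tokenized dataset
)
trainer.train()  # previously raised "You cannot perform fine-tuning on purely quantized models"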

src/transformers/trainer.py

@@ -540,6 +540,10 @@ class Trainer:
             getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
         )
 
+        _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
+            model.hf_quantizer, "is_qat_trainable", False
+        )
+
         # Filter out quantized + compiled models
         if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
             raise ValueError(
@@ -547,7 +551,7 @@
             )
 
         # At this stage the model is already loaded
-        if _is_quantized_and_base_model and not _is_peft_model(model):
+        if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable:
             raise ValueError(
                 "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"

src/transformers/utils/quantization_config.py

@@ -1150,6 +1150,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         Returns:
             [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
         """
+
         if "quantization_config" in config_dict:
             config_dict = dict(
                 sparsity_config=config_dict.get("sparsity_config"),
@@ -1160,16 +1161,23 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
     def to_dict(self) -> Dict[str, Any]:
         """
+        Quantization config to be added to config.json
+
         Serializes this instance to a Python dictionary. Returns:
             `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
         """
-        quantization_config = self.quantization_config.dict() if self.quantization_config is not None else None
-        sparsity_config = self.sparsity_config.dict() if self.sparsity_config is not None else None
+        quantization_config = {}
+        if self.quantization_config is not None:
+            quantization_config = self.quantization_config.dict()
+        else:
+            quantization_config["quant_method"] = QuantizationMethod.COMPRESSED_TENSORS
 
-        return {
-            "quantization_config": quantization_config,
-            "sparsity_config": sparsity_config,
-        }
+        if self.sparsity_config is not None:
+            quantization_config["sparsity_config"] = self.sparsity_config.dict()
+        else:
+            quantization_config["sparsity_config"] = {}
+
+        return quantization_config
 
     def to_diff_dict(self) -> Dict[str, Any]:
         """