Fix: Unexpected Keys, Improve run_compressed, Rename Test Folder (#37077)

Rahul Tuli 2025-04-04 14:30:11 -05:00 committed by GitHub
parent 531e4fcf0e
commit ebe47ce3e9
6 changed files with 48 additions and 32 deletions

View File

@@ -1352,6 +1352,7 @@ def _find_missing_and_unexpected_keys(
     if hf_quantizer is not None:
         missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
+        unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)

     # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling changes over time, or any other reason...)
     if cls._keys_to_ignore_on_load_missing is not None:

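For context: the added `update_unexpected_keys` call mirrors the existing `update_missing_keys` hook, letting a quantizer prune checkpoint-specific tensors from the unexpected-keys warning. A minimal sketch of an override, assuming a hypothetical subclass (`MyQuantizer` and the `.weight_scale` suffix are illustrative, not from this commit):

from typing import List

from transformers.quantizers.base import HfQuantizer


class MyQuantizer(HfQuantizer):
    # Hypothetical quantizer for illustration; the other abstract HfQuantizer
    # methods are omitted here.
    def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
        # Drop tensors the quantizer consumes itself (e.g. packed scales) so they
        # are not reported to the user as unexpected.
        return [k for k in unexpected_keys if not k.endswith(".weight_scale")]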
View File

@@ -46,6 +46,10 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
                 "`pip install compressed-tensors`"
             )

+        # Call post_init here to ensure proper config setup when `run_compressed`
+        # is provided directly via CompressedTensorsConfig, and to avoid duplicate logging.
+        quantization_config.post_init()
+
         from compressed_tensors.compressors import ModelCompressor

         self.compressor = ModelCompressor.from_compression_config(quantization_config)
@@ -117,16 +121,16 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
         ct_quantization_config = self.compressor.quantization_config

         if self.run_compressed:
-            if not self.is_quantization_compressed:
+            if not self.quantization_config.is_quantization_compressed:
                 raise ValueError("`run_compressed` is only supported for quantized_compressed models")
             apply_quantization_config(model, ct_quantization_config, run_compressed=True)
-        elif self.is_quantized and not self.is_quantization_compressed:
+        elif not self.quantization_config.is_quantization_compressed:
             apply_quantization_config(model, ct_quantization_config)

     def _process_model_after_weight_loading(self, model, **kwargs):
         """Decompress the loaded model if necessary - needed for QAT"""
-        if (self.is_quantization_compressed and not self.run_compressed) or self.is_sparsification_compressed:
+        if (
+            self.quantization_config.is_quantization_compressed and not self.run_compressed
+        ) or self.quantization_config.is_sparsification_compressed:
             config = kwargs.get("config", None)
             cache_path = config._name_or_path
@@ -136,36 +140,12 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
             config_file_path = cached_file(cache_path, "config.json")
             cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])

-        if self.is_quantization_compressed and not self.run_compressed:
+        if self.quantization_config.is_quantization_compressed and not self.run_compressed:
             from compressed_tensors.quantization import QuantizationStatus

             self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
             self.compressor.decompress(model_path=cache_path, model=model)

-    @property
-    def is_quantized(self):
-        return self.quantization_config.quantization_config is not None and bool(
-            self.quantization_config.quantization_config.config_groups
-        )
-
-    @property
-    def is_quantization_compressed(self):
-        from compressed_tensors.quantization import QuantizationStatus
-
-        return (
-            self.quantization_config.quantization_config is not None
-            and self.quantization_config.quantization_config.quantization_status == QuantizationStatus.COMPRESSED
-        )
-
-    @property
-    def is_sparsification_compressed(self):
-        from compressed_tensors.config.base import CompressionFormat
-
-        return (
-            self.quantization_config.sparsity_config is not None
-            and self.quantization_config.sparsity_config.format != CompressionFormat.dense.value
-        )
-
     @property
     def is_trainable(self):
         return True
@@ -173,7 +153,7 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
     def is_qat_trainable(self) -> bool:
         """Loaded models can carry out quantization aware training"""
         # models need to be decompressed to carry out QAT
-        return not self.run_compressed or not self.is_quantization_compressed
+        return not self.run_compressed or not self.quantization_config.is_quantization_compressed

     def is_serializable(self, safe_serialization=None) -> bool:
         """Models quantized using compressed tensors can be saved to disk"""

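With these changes the quantizer reads all compression state from the config object, and `run_compressed` stays the only loading attribute. A hedged usage sketch (the checkpoint id is a placeholder; this is the standard `from_pretrained` flow, not code from this commit):

from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

# Placeholder id; any quantized_compressed compressed-tensors checkpoint would do.
model = AutoModelForCausalLM.from_pretrained(
    "org/model-quantized-compressed",
    quantization_config=CompressedTensorsConfig(run_compressed=True),
)

Because `get_loading_attributes` returns only `run_compressed`, the user-supplied config merely toggles that flag on the config stored with the checkpoint.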
View File

@@ -1263,7 +1263,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         # parse from dict to load nested QuantizationScheme objects
         if config_groups or kv_cache_scheme:
-            self.quantization_config = QuantizationConfig.parse_obj(
+            self.quantization_config = QuantizationConfig.model_validate(
                 {
                     "config_groups": config_groups,
                     "quant_method": quant_method,
@@ -1282,7 +1282,19 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
                 sparsity_config.get("format"), **sparsity_config
             )

-        super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)
+        self.quant_method = QuantizationMethod.COMPRESSED_TENSORS
+
+    def post_init(self):
+        if self.run_compressed:
+            if self.is_sparsification_compressed:
+                logger.warning(
+                    "`run_compressed` is only supported for quantized_compressed models"
+                    " and not for sparsified models. Setting `run_compressed=False`"
+                )
+                self.run_compressed = False
+            elif not self.is_quantization_compressed:
+                logger.warning("`run_compressed` is only supported for compressed models. Setting `run_compressed=False`")
+                self.run_compressed = False

     @classmethod
     def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
@@ -1356,6 +1368,28 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
     def get_loading_attributes(self):
         return {"run_compressed": self.run_compressed}

+    @property
+    def is_quantized(self):
+        return bool(self.quantization_config) and bool(self.quantization_config.config_groups)
+
+    @property
+    def is_quantization_compressed(self):
+        from compressed_tensors.quantization import QuantizationStatus
+
+        return self.is_quantized and self.quantization_config.quantization_status == QuantizationStatus.COMPRESSED
+
+    @property
+    def is_sparsification_compressed(self):
+        from compressed_tensors.config import (
+            CompressionFormat,
+            SparsityCompressionConfig,
+        )
+
+        return (
+            isinstance(self.sparsity_config, SparsityCompressionConfig)
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        )
+

 @dataclass
 class FbgemmFp8Config(QuantizationConfigMixin):

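The effect of the new `post_init()` plus the relocated properties can be exercised directly (a minimal sketch; it assumes `compressed-tensors` is installed, since `is_quantization_compressed` imports from it):

from transformers.utils.quantization_config import CompressedTensorsConfig

cfg = CompressedTensorsConfig(run_compressed=True)  # no config_groups -> not quantized
cfg.post_init()  # warns and downgrades run_compressed, per the branch added above

assert cfg.run_compressed is False
assert cfg.is_quantized is False
assert cfg.is_sparsification_compressed is False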
View File

@@ -185,6 +185,7 @@ class RunCompressedTest(unittest.TestCase):
     def test_default_run_compressed__False(self):
         from compressed_tensors.linear.compressed_linear import CompressedLinear
         from compressed_tensors.quantization.utils import iter_named_leaf_modules
+        from transformers.utils.quantization_config import CompressedTensorsConfig

         quantization_config = CompressedTensorsConfig(run_compressed=False)
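The test body is truncated above. Based on its imports, a check of roughly this shape is the usual pattern (a hedged sketch, not the actual test code: with `run_compressed=False` the checkpoint is decompressed at load time, so no `CompressedLinear` modules should remain; the checkpoint id is a placeholder):

from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules

from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

model = AutoModelForCausalLM.from_pretrained(
    "org/compressed-checkpoint",  # placeholder id, not from this commit
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)
# Count leaf modules still using the compressed kernel; expect none after decompression.
compressed_linear_count = sum(
    isinstance(module, CompressedLinear) for _, module in iter_named_leaf_modules(model)
)
assert compressed_linear_count == 0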