Improve model loading for compressed tensor models (#36152)

* Disable warnings for stacked compressors
* Introduce two new hooks in HfQuantizer lifecycle
to allow updates to missing and unexpected keys
* Update missing and unexpected keys
for stacked compressors
* Add tests
* Fix: run_compressed cases
* Fix: uncompressed cases

* Rename compressed_tensor folder to compressed_tensors
Move RunCompressedTest to the same file
Update tests to unittest
Rahul Tuli 2025-02-24 06:47:21 -06:00 committed by GitHub
parent 4dbf17c17f
commit 884a8ea1f0
8 changed files with 307 additions and 176 deletions

View File

@@ -4673,6 +4673,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)
# retrieve weights on meta device and put them back on CPU.
# This is not ideal in terms of memory, but if we don't do that now, we can't initialize them in the next step
@@ -4993,6 +4994,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys_after_loading(model_to_load, missing_keys, prefix)
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:

View File

@@ -109,6 +109,27 @@ class HfQuantizer(ABC):
"""
return missing_keys
def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
"""
Override this method if you want to adjust the `unexpected_keys`.
Args:
unexpected_keys (`List[str]`, *optional*):
The list of unexpected keys in the checkpoint compared to the state dict of the model
"""
return unexpected_keys
def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]:
"""
Override this method if you want to adjust the `missing_keys` after loading the model params,
but before the model is post-processed.
Args:
missing_keys (`List[str]`, *optional*):
The list of missing keys in the checkpoint compared to the state dict of the model
"""
return missing_keys
def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
"""
Override this method if you want to adjust the `expected_keys`.
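As a sketch of how a quantization backend could use the two new hooks, the hypothetical quantizer below filters keys with made-up name patterns; the remaining abstract members of `HfQuantizer` are omitted for brevity, so this class is a sketch rather than a working backend.

# Hypothetical quantizer illustrating the new hooks; the patterns are examples only.
import re
from typing import List

from transformers.quantizers.base import HfQuantizer


class MyPackedQuantizer(HfQuantizer):
    # Remaining abstract methods/properties of HfQuantizer are omitted here.

    def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
        # Checkpoint-only packing metadata should not be reported as unexpected.
        ignore = re.compile(r"\.(weight_packed|weight_scale)$")
        return [k for k in unexpected_keys if ignore.search(k) is None]

    def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        # Dense weights rebuilt in _process_model_after_weight_loading are not truly missing.
        return [k for k in missing_keys if not k.endswith(".weight")]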

View File

@@ -14,6 +14,8 @@
import os
import re
from typing import List
from ..utils import is_compressed_tensors_available, is_torch_available, logging
from ..utils.quantization_config import CompressedTensorsConfig
@@ -50,6 +52,45 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
self.run_compressed = quantization_config.run_compressed
self.quantization_config = quantization_config
def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]:
"""
Update missing keys after loading the model params. This is necessary for compressed tensors
to load the model correctly: some weight keys are expected to be missing because their dense
weights are reconstructed by ModelCompressor in _process_model_after_weight_loading.
This function removes those expected missing keys and returns the remaining ones.
"""
if self.run_compressed:
return missing_keys
# We expect some keys to be missing for compressed models.
# This is fine since those weights are reconstructed by ModelCompressor
# in _process_model_after_weight_loading
expected_missing_keys = self.compressor.get_missing_module_keys(model)
return [
key for key in missing_keys if not any(re.match(f".*{pattern}", key) for pattern in expected_missing_keys)
]
def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
"""
Remove keys from `unexpected_keys` that are expected to be present only in the compressed checkpoint.
Args:
unexpected_keys (`List[str]`, *optional*):
The list of unexpected keys in the checkpoint compared to the state dict of the model
"""
if self.run_compressed:
return unexpected_keys
# We expect some unexpected keys in the
# safetensors file of compressed models
keys_to_ignore = self.compressor.get_unexpected_file_keys(model)
return [key for key in unexpected_keys if not any(re.match(f".*{pattern}", key) for pattern in keys_to_ignore)]
def validate_environment(self, *args, **kwargs):
if not is_compressed_tensors_available():
raise ImportError(
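The key filtering in `update_missing_keys_after_loading` and `update_unexpected_keys` above is a plain pattern match over key names. The snippet below replays that filtering with made-up keys and patterns; the real patterns come from `ModelCompressor.get_missing_module_keys` and `get_unexpected_file_keys`.

# Illustration of the pattern-based key filtering, with made-up keys and patterns.
import re

missing_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "lm_head.weight",
]
# e.g. keys whose dense weights ModelCompressor will reconstruct later
expected_missing_keys = ["self_attn.q_proj.weight", "mlp.gate_proj.weight"]

remaining = [
    key
    for key in missing_keys
    if not any(re.match(f".*{pattern}", key) for pattern in expected_missing_keys)
]
print(remaining)  # ['lm_head.weight'] -- only genuinely missing keys are reported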
@@ -75,9 +116,11 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
ct_quantization_config = self.compressor.quantization_config
if self.run_compressed and self.is_quantization_compressed:
if self.run_compressed:
if not self.is_quantization_compressed:
raise ValueError("`run_compressed` is only supported for quantized_compressed models")
apply_quantization_config(model, ct_quantization_config, run_compressed=True)
elif not self.is_quantization_compressed:
elif self.is_quantized and not self.is_quantization_compressed:
apply_quantization_config(model, ct_quantization_config)
def _process_model_after_weight_loading(self, model, **kwargs):
@@ -99,6 +142,12 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
self.compressor.decompress(model_path=cache_path, model=model)
@property
def is_quantized(self):
return self.quantization_config.quantization_config is not None and bool(
self.quantization_config.quantization_config.config_groups
)
@property
def is_quantization_compressed(self):
from compressed_tensors.quantization import QuantizationStatus
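Putting the pieces together, the two supported loading paths look roughly as follows. The model stub is one of the checkpoints used in the tests below; running this requires the compressed-tensors package to be installed and enough memory for the checkpoint.

# Two loading paths for a compressed-tensors checkpoint (stub taken from the tests below).
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

stub = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"

# Default path: run_compressed=True keeps CompressedLinear modules and is only
# valid for quantization-compressed checkpoints.
model_compressed = AutoModelForCausalLM.from_pretrained(stub)

# Opt-out path: weights are decompressed after loading; this relies on
# update_missing_keys_after_loading to silence the expected missing-key warnings.
model_dense = AutoModelForCausalLM.from_pretrained(
    stub,
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)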

View File

@@ -1,80 +0,0 @@
import gc
import unittest
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_compressed_tensors, require_torch
from transformers.utils import is_torch_available
if is_torch_available():
import torch
@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
model_sparse_uncompressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_uncompressed"
model_sparse_compressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_compressed"
prompt = "Paris is the capital of which country?"
stubs = [model_sparse_uncompressed, model_sparse_compressed]
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_compressed_uncompressed_model_shapes(self):
"""
Check that the weights are the same between
uncompressed and compressed-decompressed model
Sparse compressed modules' weights are "packed" and shape/value will
differ
"""
def _has_nested_attr(obj, attr_path):
attrs = attr_path.split(".")
for attr in attrs:
if not hasattr(obj, attr):
return None
obj = getattr(obj, attr)
return obj
from compressed_tensors.quantization.utils import iter_named_leaf_modules
uncompressed_model = AutoModelForCausalLM.from_pretrained(
self.model_sparse_uncompressed,
)
compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
self.model_sparse_compressed,
)
for name, submodule in iter_named_leaf_modules(
uncompressed_model,
):
if comp_decomp_obj := _has_nested_attr(compressed_model_decompressed, name):
if hasattr(submodule, "weight"):
assert torch.equal(submodule.weight, comp_decomp_obj.weight)
def test_run_compressed_outputs_match(self):
"""Check that uncompressed and compressed-decompressed model outputs are the same"""
from transformers import AutoTokenizer
for stub in self.stubs:
tokenizer = AutoTokenizer.from_pretrained(stub)
input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
uncompressed_model = AutoModelForCausalLM.from_pretrained(
self.model_sparse_uncompressed,
)
output_rc_true = uncompressed_model.generate(input_ids, max_new_tokens=100)
compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
self.model_sparse_compressed,
)
output_rc_false = compressed_model_decompressed.generate(input_ids, max_new_tokens=100)
assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])

View File

@@ -1,94 +0,0 @@
import gc
import unittest
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_compressed_tensors, require_torch
from transformers.utils import is_torch_available
if is_torch_available():
import torch
@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
prompt = "Paris is the capital of which country?"
stubs = [tinyllama_w4a16, tinyllama_w8a8]
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_default_run_compressed__True(self):
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules
for stub in self.stubs:
model = AutoModelForCausalLM.from_pretrained(
stub,
)
compressed_linear_counts = 0
for _, submodule in iter_named_leaf_modules(
model,
):
if isinstance(submodule, CompressedLinear):
compressed_linear_counts += 1
# some Linear modules are not compressed, e.g. lm_head
assert compressed_linear_counts > 0
def test_default_run_compressed__False(self):
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules
from transformers.utils.quantization_config import CompressedTensorsConfig
quantization_config = CompressedTensorsConfig(run_compressed=False)
for stub in self.stubs:
model = AutoModelForCausalLM.from_pretrained(
stub,
quantization_config=quantization_config,
)
compressed_linear_counts = 0
for _, submodule in iter_named_leaf_modules(
model,
):
if isinstance(submodule, CompressedLinear):
compressed_linear_counts += 1
# No modules should be CompressedLinear
assert compressed_linear_counts == 0
def test_run_compressed_outputs_match(self):
"""Check that run_compressed=True/False output are the same"""
from transformers import AutoTokenizer
from transformers.utils.quantization_config import CompressedTensorsConfig
quantization_config = CompressedTensorsConfig(run_compressed=False)
for stub in self.stubs:
tokenizer = AutoTokenizer.from_pretrained(stub)
input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
model_run_compressed__True = AutoModelForCausalLM.from_pretrained(
stub,
)
output_rc_true = model_run_compressed__True.generate(input_ids, max_new_tokens=100)
model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
stub,
quantization_config=quantization_config,
)
output_rc_false = model_run_compressed__False.generate(input_ids, max_new_tokens=100)
assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])

View File

@@ -0,0 +1,231 @@
import gc
import unittest
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_compressed_tensors, require_torch
from transformers.utils import is_torch_available
from transformers.utils.quantization_config import CompressedTensorsConfig
if is_torch_available():
import torch
@require_compressed_tensors
@require_torch
class StackCompressedModelTest(unittest.TestCase):
# Define stubs as class attributes
compressed_uncompressed_model_stubs = [
(
"nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
),
(
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
),
(
"nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
),
]
# Flatten the list for tests that require a single list of stubs.
model_stubs = [stub for pair in compressed_uncompressed_model_stubs for stub in pair]
# For the outputs matching test, use the sparse-only pair.
sparse_compressed_model = "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed"
sparse_uncompressed_model = "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
prompt = "Paris is the capital of which country?"
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_compressed_uncompressed_model_shapes(self):
"""
Verify that the weights of an uncompressed model and its decompressed compressed counterpart match.
Note: quantization-compressed weights can differ slightly after decompression, so they are compared with a tolerance.
"""
def _has_nested_attr(obj, attr_path):
attrs = attr_path.split(".")
for attr in attrs:
if not hasattr(obj, attr):
return None
obj = getattr(obj, attr)
return obj
from compressed_tensors.quantization.utils import iter_named_leaf_modules
for compressed_model, uncompressed_model in self.compressed_uncompressed_model_stubs:
with self.subTest(compressed_model=compressed_model, uncompressed_model=uncompressed_model):
uncompressed = AutoModelForCausalLM.from_pretrained(
uncompressed_model,
device_map="auto",
torch_dtype="auto",
quantization_config=CompressedTensorsConfig(run_compressed=False),
)
compressed_decompressed = AutoModelForCausalLM.from_pretrained(
compressed_model,
device_map="auto",
torch_dtype="auto",
quantization_config=CompressedTensorsConfig(run_compressed=False),
)
for name, submodule in iter_named_leaf_modules(uncompressed):
comp_decomp_obj = _has_nested_attr(compressed_decompressed, name)
if comp_decomp_obj is not None and hasattr(submodule, "weight"):
if "sparse-only" in uncompressed_model:
self.assertTrue(
torch.equal(submodule.weight, comp_decomp_obj.weight),
f"Weight mismatch for module '{name}' in sparse-only model.",
)
else:
self.assertTrue(
torch.allclose(submodule.weight, comp_decomp_obj.weight, atol=0.2),
f"Weight mismatch for module '{name}' in quantized-only or stacked model.",
)
def test_outputs_match(self):
"""
Ensure that the generated outputs match between the uncompressed model
and its decompressed compressed counterpart.
"""
tokenizer = AutoTokenizer.from_pretrained(self.sparse_uncompressed_model)
input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
uncompressed = AutoModelForCausalLM.from_pretrained(
self.sparse_uncompressed_model,
device_map="auto",
torch_dtype="auto",
quantization_config=CompressedTensorsConfig(run_compressed=False),
)
output_uncompressed = uncompressed.generate(input_ids.to(uncompressed.device), max_new_tokens=100)
decompressed = AutoModelForCausalLM.from_pretrained(
self.sparse_compressed_model,
device_map="auto",
torch_dtype="auto",
quantization_config=CompressedTensorsConfig(run_compressed=False),
)
output_decompressed = decompressed.generate(input_ids.to(decompressed.device), max_new_tokens=100)
self.assertEqual(
tokenizer.decode(output_uncompressed[0]),
tokenizer.decode(output_decompressed[0]),
"Generated outputs do not match between compressed and uncompressed models.",
)
def test_no_warnings_for_all_models(self):
"""
Confirm that loading any model using compressed tensors does not trigger
warnings about missing or unexpected keys.
"""
for model_stub in self.model_stubs:
with self.subTest(model_stub=model_stub):
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
AutoModelForCausalLM.from_pretrained(
model_stub,
device_map="auto",
torch_dtype="auto",
quantization_config=CompressedTensorsConfig(run_compressed=False),
)
for warning in caught_warnings:
self.assertNotIn(
"missing keys",
str(warning.message).lower(),
f"'missing keys' found in warnings for model {model_stub}",
)
self.assertNotIn(
"unexpected keys",
str(warning.message).lower(),
f"'unexpected keys' found in warnings for model {model_stub}",
)
@require_compressed_tensors
@require_torch
class RunCompressedTest(unittest.TestCase):
tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
prompt = "Paris is the capital of which country?"
stubs = [tinyllama_w4a16, tinyllama_w8a8]
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_default_run_compressed__True(self):
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules
for stub in self.stubs:
model = AutoModelForCausalLM.from_pretrained(
stub,
)
compressed_linear_counts = 0
for _, submodule in iter_named_leaf_modules(
model,
):
if isinstance(submodule, CompressedLinear):
compressed_linear_counts += 1
# some Linear modules are not compressed, e.g. lm_head
assert compressed_linear_counts > 0
def test_default_run_compressed__False(self):
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules
from transformers.utils.quantization_config import CompressedTensorsConfig
quantization_config = CompressedTensorsConfig(run_compressed=False)
for stub in self.stubs:
model = AutoModelForCausalLM.from_pretrained(
stub,
quantization_config=quantization_config,
)
compressed_linear_counts = 0
for _, submodule in iter_named_leaf_modules(
model,
):
if isinstance(submodule, CompressedLinear):
compressed_linear_counts += 1
# No modules should be CompressedLinear
assert compressed_linear_counts == 0
def test_run_compressed_outputs_match(self):
"""Check that run_compressed=True/False output are the same"""
from transformers import AutoTokenizer
from transformers.utils.quantization_config import CompressedTensorsConfig
quantization_config = CompressedTensorsConfig(run_compressed=False)
for stub in self.stubs:
tokenizer = AutoTokenizer.from_pretrained(stub)
input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
model_run_compressed__True = AutoModelForCausalLM.from_pretrained(
stub,
)
output_rc_true = model_run_compressed__True.generate(input_ids, max_new_tokens=100)
model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
stub,
quantization_config=quantization_config,
)
output_rc_false = model_run_compressed__False.generate(input_ids, max_new_tokens=100)
assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])