Run model as compressed/uncompressed mode (#34719)

* draft, run model as compressed/uncompressed mode

* draft

* run run_compressed=False

* run_compressed as attr

* set run_compressed=False using quantization_config

* remove redundant line

* make is_qat_trainable dependent on run_compressed status

* add tests

* lint

* fill in docstring

* add decompress

* comments

* decompress if model is compressed and not run_compressed

* apply_quant_config logic fix -- populate statedict properly

* comments

* remove non-compressed model

* make is_compressed as property

* cosmetic

* run apply_quant_config for non-compressed models -- populate scales and zeropoints

* add pathway for decompressing sparse models

* typo on is_quantization_compressed

* lint

* fix typo
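
In short: a compressed-tensors checkpoint can now be loaded either in its compressed execution form (the default) or decompressed for finetuning / quantization-aware training by passing a config with run_compressed=False. A minimal usage sketch, using a model stub taken from the tests added in this PR (requires the compressed-tensors package):

    from transformers import AutoModelForCausalLM
    from transformers.utils.quantization_config import CompressedTensorsConfig

    stub = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"  # stub from the new tests

    # Default (run_compressed=True): quantized linear layers stay in their compressed
    # format and are executed as CompressedLinear modules.
    compressed_model = AutoModelForCausalLM.from_pretrained(stub)

    # run_compressed=False: the checkpoint is decompressed at load time, so the model
    # uses regular linear modules and can be used for quantization-aware training.
    decompressed_model = AutoModelForCausalLM.from_pretrained(
        stub,
        quantization_config=CompressedTensorsConfig(run_compressed=False),
    )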
Authored by George on 2024-12-13 02:23:31 -05:00, committed by GitHub
parent 31f9a289a6
commit e4e404fdd0
9 changed files with 250 additions and 18 deletions

src/transformers/modeling_utils.py

@@ -3597,7 +3597,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             else:
                 config.quantization_config = quantization_config
-            hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
+
+            hf_quantizer = AutoHfQuantizer.from_config(
+                config.quantization_config,
+                pre_quantized=pre_quantized,
+            )
         else:
             hf_quantizer = None
@@ -4281,7 +4286,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             dispatch_model(model, **device_map_kwargs)

         if hf_quantizer is not None:
-            hf_quantizer.postprocess_model(model)
+            hf_quantizer.postprocess_model(model, config=config)
             model.hf_quantizer = hf_quantizer

         if _adapter_model_path is not None:

src/transformers/quantizers/auto.py

@@ -173,13 +173,14 @@ class AutoHfQuantizer:
             quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

         if (
-            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
+            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config, CompressedTensorsConfig))
             and quantization_config_from_args is not None
         ):
             # special case for GPTQ / AWQ / FbgemmFp8 config collision
             loading_attr_dict = quantization_config_from_args.get_loading_attributes()
             for attr, val in loading_attr_dict.items():
                 setattr(quantization_config, attr, val)
             warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."

         if warning_msg != "":
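
Adding CompressedTensorsConfig to the collision handling above means that when a checkpoint already carries a compressed-tensors quantization config and the user also passes one to `from_pretrained`, only the loading attributes of the user config are copied onto the checkpoint config. A rough sketch of that merge (the two configs here are stand-ins built in place rather than parsed from a real checkpoint; requires the compressed-tensors package):

    from transformers.utils.quantization_config import CompressedTensorsConfig

    checkpoint_config = CompressedTensorsConfig()                # stands in for the config read from config.json
    user_config = CompressedTensorsConfig(run_compressed=False)  # what the user passed to from_pretrained

    # Mirrors the loop above: only loading attributes are overwritten.
    for attr, val in user_config.get_loading_attributes().items():  # {"run_compressed": False}
        setattr(checkpoint_config, attr, val)

    assert checkpoint_config.run_compressed is False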

src/transformers/quantizers/quantizer_awq.py

@@ -111,7 +111,7 @@ class AwqQuantizer(HfQuantizer):
                 " Please double check your model architecture, or submit an issue on github if you think this is a bug."
             )

-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         if self.quantization_config.do_fuse:
             from ..integrations import fuse_awq_modules

src/transformers/quantizers/quantizer_compressed_tensors.py

@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 from ..utils import is_compressed_tensors_available, is_torch_available, logging
-from ..utils.quantization_config import QuantizationConfigMixin
+from ..utils.quantization_config import CompressedTensorsConfig
 from .base import HfQuantizer
@@ -32,12 +35,13 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
     requires_calibration = True
     required_packages = ["compressed_tensors"]

-    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+    def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
         super().__init__(quantization_config, **kwargs)
         from compressed_tensors.compressors import ModelCompressor

         self.compressor = ModelCompressor.from_compression_config(quantization_config)
+        self.run_compressed = quantization_config.run_compressed
+        self.quantization_config = quantization_config

     def validate_environment(self, *args, **kwargs):
         if not is_compressed_tensors_available():
@@ -63,20 +67,57 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
         from compressed_tensors.quantization import apply_quantization_config

         ct_quantization_config = self.compressor.quantization_config
-        apply_quantization_config(model, ct_quantization_config, run_compressed=True)
-
-    def _process_model_after_weight_loading(self, model, **kwargs) -> None:
-        pass
+
+        if self.run_compressed and self.is_quantization_compressed:
+            apply_quantization_config(model, ct_quantization_config, run_compressed=True)
+        elif not self.is_quantization_compressed:
+            apply_quantization_config(model, ct_quantization_config)
+
+    def _process_model_after_weight_loading(self, model, **kwargs):
+        """Decompress loaded model if necessary - need for qat"""
+        if (self.is_quantization_compressed and not self.run_compressed) or self.is_sparsification_compressed:
+            config = kwargs.get("config", None)
+            cache_path = config._name_or_path
+            if not os.path.exists(cache_path):
+                from transformers.utils import cached_file
+
+                config_file_path = cached_file(cache_path, "config.json")
+                cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])
+
+            if self.is_quantization_compressed and not self.run_compressed:
+                from compressed_tensors.quantization import QuantizationStatus
+
+                self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
+
+            self.compressor.decompress(model_path=cache_path, model=model)

     @property
-    def is_trainable(self) -> bool:
-        """Models quantized using compressed tensors can be finetuned"""
+    def is_quantization_compressed(self):
+        from compressed_tensors.quantization import QuantizationStatus
+
+        return (
+            self.quantization_config.quantization_config is not None
+            and self.quantization_config.quantization_config.quantization_status == QuantizationStatus.COMPRESSED
+        )
+
+    @property
+    def is_sparsification_compressed(self):
+        from compressed_tensors.config.base import CompressionFormat
+
+        return (
+            self.quantization_config.sparsity_config is not None
+            and self.quantization_config.sparsity_config.format != CompressionFormat.dense.value
+        )
+
+    @property
+    def is_trainable(self):
         return True

     @property
     def is_qat_trainable(self) -> bool:
         """Loaded Models can carry out quantization aware training"""
-        return True
+        # models need to be decompressed carry out qat
+        return not self.run_compressed or not self.is_quantization_compressed

     def is_serializable(self, safe_serialization=None) -> bool:
         """Models quantized using compressed tensors can be saved to disk"""

src/transformers/quantizers/quantizer_quanto.py

@@ -197,7 +197,7 @@ class QuantoHfQuantizer(HfQuantizer):
             )
         model.config.quantization_config = self.quantization_config

-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         return model

     @property

src/transformers/quantizers/quantizer_torchao.py

@@ -195,7 +195,7 @@ class TorchAoHfQuantizer(HfQuantizer):
         module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())

-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         return

src/transformers/utils/quantization_config.py

@@ -1077,7 +1077,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
             dictionary mapping group name to a quantization scheme definition
         format (`str`, *optional*, defaults to `"dense"`):
-            format the model is represented as
+            format the model is represented as. Set `run_compressed` True to execute model as the
+            compressed format if not `dense`
         quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
             status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
         kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*):
@@ -1090,6 +1091,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
             configuration for sparsity compression
         quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
             do not override, should be compressed-tensors
+        run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
+            emulate compressed model execution if True, otherwise use default submodule
     """

     def __init__(
@@ -1102,14 +1105,17 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         ignore: Optional[List[str]] = None,
         sparsity_config: Dict[str, Any] = None,
         quant_method: str = "compressed-tensors",
+        run_compressed: bool = True,
         **kwargs,
     ):
-        from compressed_tensors import QuantizationConfig
         from compressed_tensors.config import SparsityCompressionConfig
+        from compressed_tensors.quantization import QuantizationConfig

         self.quantization_config = None
         self.sparsity_config = None

+        self.run_compressed = run_compressed
+
         # parse from dict to load nested QuantizationScheme objects
         if config_groups or kv_cache_scheme:
             self.quantization_config = QuantizationConfig.parse_obj(
@@ -1121,6 +1127,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
                     "kv_cache_scheme": kv_cache_scheme,
                     "global_compression_ratio": global_compression_ratio,
                     "ignore": ignore,
+                    "run_compressed": run_compressed,
                     **kwargs,
                 }
             )
@@ -1149,6 +1156,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         Returns:
             [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
         """
+
         if "quantization_config" in config_dict:
@@ -1200,6 +1208,9 @@ class CompressedTensorsConfig(QuantizationConfigMixin):

         return serializable_config_dict

+    def get_loading_attributes(self):
+        return {"run_compressed": self.run_compressed}
+

 @dataclass
 class FbgemmFp8Config(QuantizationConfigMixin):

New test file: sparse compressed vs. uncompressed compressed-tensors models

@@ -0,0 +1,80 @@
import gc
import unittest

from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_compressed_tensors, require_torch
from transformers.utils import is_torch_available


if is_torch_available():
    import torch


@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
    model_sparse_uncompressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_uncompressed"
    model_sparse_compressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_compressed"

    prompt = "Paris is the capital of which country?"

    stubs = [model_sparse_uncompressed, model_sparse_compressed]

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_compressed_uncompressed_model_shapes(self):
        """
        Check that the weights are the same between
        uncompressed and compressed-decompressed model
        Sparse compressed modules' weights are "packed" and shape/value will
        differ
        """

        def _has_nested_attr(obj, attr_path):
            attrs = attr_path.split(".")
            for attr in attrs:
                if not hasattr(obj, attr):
                    return None
                obj = getattr(obj, attr)
            return obj

        from compressed_tensors.quantization.utils import iter_named_leaf_modules

        uncompressed_model = AutoModelForCausalLM.from_pretrained(
            self.model_sparse_uncompressed,
        )

        compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
            self.model_sparse_compressed,
        )

        for name, submodule in iter_named_leaf_modules(
            uncompressed_model,
        ):
            if comp_decomp_obj := _has_nested_attr(compressed_model_decompressed, name):
                if hasattr(submodule, "weight"):
                    assert torch.equal(submodule.weight, comp_decomp_obj.weight)

    def test_run_compressed_outputs_match(self):
        """Check that uncompressed and compressed-decompressed model outputs are the same"""
        from transformers import AutoTokenizer

        for stub in self.stubs:
            tokenizer = AutoTokenizer.from_pretrained(stub)
            input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids

            uncompressed_model = AutoModelForCausalLM.from_pretrained(
                self.model_sparse_uncompressed,
            )
            output_rc_true = uncompressed_model.generate(input_ids, max_new_tokens=100)

            compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
                self.model_sparse_compressed,
            )
            output_rc_false = compressed_model_decompressed.generate(input_ids, max_new_tokens=100)

            assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])

New test file: run_compressed=True vs. run_compressed=False loading of compressed-tensors models

@@ -0,0 +1,94 @@
import gc
import unittest

from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_compressed_tensors, require_torch
from transformers.utils import is_torch_available


if is_torch_available():
    import torch


@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
    tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
    tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"

    prompt = "Paris is the capital of which country?"

    stubs = [tinyllama_w4a16, tinyllama_w8a8]

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_default_run_compressed__True(self):
        from compressed_tensors.linear.compressed_linear import CompressedLinear
        from compressed_tensors.quantization.utils import iter_named_leaf_modules

        for stub in self.stubs:
            model = AutoModelForCausalLM.from_pretrained(
                stub,
            )
            compressed_linear_counts = 0

            for _, submodule in iter_named_leaf_modules(
                model,
            ):
                if isinstance(submodule, CompressedLinear):
                    compressed_linear_counts += 1

            # some linear models are not compressed - ex. lm_head
            assert compressed_linear_counts > 0

    def test_default_run_compressed__False(self):
        from compressed_tensors.linear.compressed_linear import CompressedLinear
        from compressed_tensors.quantization.utils import iter_named_leaf_modules
        from transformers.utils.quantization_config import CompressedTensorsConfig

        quantization_config = CompressedTensorsConfig(run_compressed=False)

        for stub in self.stubs:
            model = AutoModelForCausalLM.from_pretrained(
                stub,
                quantization_config=quantization_config,
            )
            compressed_linear_counts = 0

            for _, submodule in iter_named_leaf_modules(
                model,
            ):
                if isinstance(submodule, CompressedLinear):
                    compressed_linear_counts += 1

            # No modules should be CompressedLinear
            assert compressed_linear_counts == 0

    def test_run_compressed_outputs_match(self):
        """Check that run_compressed=True/False output are the same"""
        from transformers import AutoTokenizer
        from transformers.utils.quantization_config import CompressedTensorsConfig

        quantization_config = CompressedTensorsConfig(run_compressed=False)

        for stub in self.stubs:
            tokenizer = AutoTokenizer.from_pretrained(stub)
            input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids

            model_run_compressed__True = AutoModelForCausalLM.from_pretrained(
                stub,
            )
            output_rc_true = model_run_compressed__True.generate(input_ids, max_new_tokens=100)

            model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
                stub,
                quantization_config=quantization_config,
            )
            output_rc_false = model_run_compressed__False.generate(input_ids, max_new_tokens=100)

            assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])