# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, Optional, Union

from ..models.auto.configuration_auto import AutoConfig
from ..utils.quantization_config import (
    AqlmConfig,
    AwqConfig,
    BitsAndBytesConfig,
    EetqConfig,
    GPTQConfig,
    HqqConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer

AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
    "bitsandbytes_4bit": Bnb4BitHfQuantizer,
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
    "aqlm": AqlmHfQuantizer,
    "quanto": QuantoHfQuantizer,
    "eetq": EetqHfQuantizer,
    "hqq": HqqHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "awq": AwqConfig,
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "eetq": EetqConfig,
    "gptq": GPTQConfig,
    "aqlm": AqlmConfig,
    "quanto": QuantoConfig,
    "hqq": HqqConfig,
}

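# Illustrative note (assumed example values, not taken from a real checkpoint): both mappings are
# keyed by the serialized `quant_method` string, so a `quantization_config` dict such as
#
#     {"quant_method": "gptq", "bits": 4, "group_size": 128}
#
# would be resolved to `GPTQConfig` by `AutoQuantizationConfig.from_dict` and to `GptqHfQuantizer`
# by `AutoHfQuantizer.from_config`.
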
class AutoQuantizationConfig:
    """
    The Auto-HF quantization config class that takes care of automatically dispatching to the correct
    quantization config given a quantization config stored in a dictionary.
    """

    @classmethod
    def from_dict(cls, quantization_config_dict: Dict):
        quant_method = quantization_config_dict.get("quant_method", None)
        # We need special care for bnb models to make sure everything is backward compatible.
        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized."
            )

        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
        return target_cls.from_dict(quantization_config_dict)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if getattr(model_config, "quantization_config", None) is None:
            raise ValueError(
                f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
            )
        quantization_config_dict = model_config.quantization_config
        quantization_config = cls.from_dict(quantization_config_dict)
        # Update with potential kwargs that are passed through from_pretrained.
        quantization_config.update(kwargs)
        return quantization_config

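# Usage sketch (illustrative only; the checkpoint id below is a hypothetical placeholder, not a
# real model):
#
#     from transformers.quantizers.auto import AutoQuantizationConfig
#
#     # Resolve the concrete config class from a quantized checkpoint's `quantization_config`.
#     config = AutoQuantizationConfig.from_pretrained("some-org/some-quantized-model")
#
#     # Or build it from an already-serialized dict; bare bnb flags are routed through the
#     # backward-compatibility branch of `from_dict` above.
#     config = AutoQuantizationConfig.from_dict({"load_in_4bit": True})  # -> BitsAndBytesConfig
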
class AutoHfQuantizer:
    """
    The Auto-HF quantizer class that takes care of automatically instantiating the correct
    `HfQuantizer` given the `QuantizationConfig`.
    """

    @classmethod
    def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
        # Convert it to a QuantizationConfig if the q_config is a dict
        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        quant_method = quantization_config.quant_method

        # Again, we need special care for bnb as we have a single quantization config
        # class for both 4-bit and 8-bit quantization
        if quant_method == QuantizationMethod.BITS_AND_BYTES:
            if quantization_config.load_in_8bit:
                quant_method += "_8bit"
            else:
                quant_method += "_4bit"

        if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
        return target_cls(quantization_config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        quantization_config = AutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls.from_config(quantization_config)

    @classmethod
    def merge_quantization_configs(
        cls,
        quantization_config: Union[dict, QuantizationConfigMixin],
        quantization_config_from_args: Optional[QuantizationConfigMixin],
    ):
        """
        Handles situations where both a `quantization_config` from the arguments and a `quantization_config` from the model config are present.
        """
        if quantization_config_from_args is not None:
            warning_msg = (
                "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
                " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
            )
        else:
            warning_msg = ""

        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        if isinstance(quantization_config, (GPTQConfig, AwqConfig)) and quantization_config_from_args is not None:
            # Special case for GPTQ / AWQ config collision
            loading_attr_dict = quantization_config_from_args.get_loading_attributes()
            for attr, val in loading_attr_dict.items():
                setattr(quantization_config, attr, val)
            warning_msg += f" However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the ones you passed to `from_pretrained`. The rest will be ignored."

        if warning_msg != "":
            warnings.warn(warning_msg)

        return quantization_config
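
# Usage sketch (illustrative only; in practice `PreTrainedModel.from_pretrained` drives this class,
# and `checkpoint_config` / `user_config` below are hypothetical placeholders):
#
#     from transformers import BitsAndBytesConfig
#     from transformers.quantizers.auto import AutoHfQuantizer
#
#     # Dicts and concrete configs both work; dicts are first resolved via `AutoQuantizationConfig`.
#     quantizer = AutoHfQuantizer.from_config(BitsAndBytesConfig(load_in_8bit=True))
#
#     # When a checkpoint config and a user-passed config collide, the checkpoint's config wins,
#     # except for the loading-only attributes of GPTQ / AWQ configs.
#     merged = AutoHfQuantizer.merge_quantization_configs(checkpoint_config, user_config)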