diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index a1096c110df..51936708159 100755
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -247,6 +247,25 @@ class HfQuantizer(ABC):
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
         )
 
+    @staticmethod
+    def get_modules_to_not_convert(
+        model: "PreTrainedModel",
+        skip_modules: Optional[List[str]] = None,
+        keep_in_fp32_modules: Optional[List[str]] = None,
+    ):
+        from ..integrations import get_keys_to_not_convert
+
+        modules_to_not_convert = []
+        if skip_modules is None:
+            modules_to_not_convert = get_keys_to_not_convert(model)
+        else:
+            modules_to_not_convert = skip_modules
+
+        if keep_in_fp32_modules is not None:
+            modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        return modules_to_not_convert
+
     @property
     def is_qat_trainable(self) -> bool:
         """Flag indicating whether the quantized model can carry out quantization aware training"""
diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 9109fccb575..28460ac38ed 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib.metadata
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
 
 from packaging import version
 
@@ -96,13 +96,14 @@ class AwqQuantizer(HfQuantizer):
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        from ..integrations import replace_quantization_scales, replace_with_awq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, has_been_replaced = replace_with_awq_linear(
             model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
diff --git a/src/transformers/quantizers/quantizer_bitnet.py b/src/transformers/quantizers/quantizer_bitnet.py
index 3607caa0073..e56bb161ac4 100644
--- a/src/transformers/quantizers/quantizer_bitnet.py
+++ b/src/transformers/quantizers/quantizer_bitnet.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from .base import HfQuantizer
 
@@ -81,16 +81,14 @@ class BitNetHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bitnet_linear
+        from ..integrations import replace_with_bitnet_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_bitnet_linear(
             model,
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index ab04a295460..7fb9176c467 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -288,23 +288,16 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 093d612b914..cac339b16b9 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -245,23 +245,16 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:
diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py
index 7dfce75c373..988f90789ac 100644
--- a/src/transformers/quantizers/quantizer_eetq.py
+++ b/src/transformers/quantizers/quantizer_eetq.py
@@ -155,16 +155,14 @@ class EetqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_eetq_linear
+        from ..integrations import replace_with_eetq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_eetq_linear(
             model,
diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
index 07d5ce87ef6..dd0927765d1 100644
--- a/src/transformers/quantizers/quantizer_fbgemm_fp8.py
+++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
@@ -161,16 +161,14 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear
+        from ..integrations import replace_with_fbgemm_fp8_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fbgemm_fp8_linear(
             model,
diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py
index 7816ed2f583..ac6b7355121 100644
--- a/src/transformers/quantizers/quantizer_finegrained_fp8.py
+++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -162,16 +162,14 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        modules_to_not_convert: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         from ..integrations.finegrained_fp8 import replace_with_fp8_linear
 
-        self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
-
-        if self.quantization_config.modules_to_not_convert:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fp8_linear(
             model,
diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py
index 775fea8f490..93ab958a30c 100755
--- a/src/transformers/quantizers/quantizer_hqq.py
+++ b/src/transformers/quantizers/quantizer_hqq.py
@@ -273,12 +273,8 @@ class HqqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = None,
         **kwargs,
     ):
-        keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else []
-
         # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
         # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
         model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py
index 230e8efe150..be760f0d430 100644
--- a/src/transformers/quantizers/quantizer_quanto.py
+++ b/src/transformers/quantizers/quantizer_quanto.py
@@ -177,20 +177,13 @@ class QuantoHfQuantizer(HfQuantizer):
         )
 
     def _process_model_before_weight_loading(
-        self, model: "PreTrainedModel", keep_in_fp32_modules: List[str] = [], **kwargs
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_quanto_layers
+        from ..integrations import replace_with_quanto_layers
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.modules_to_not_convert is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, _ = replace_with_quanto_layers(
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
diff --git a/src/transformers/quantizers/quantizer_spqr.py b/src/transformers/quantizers/quantizer_spqr.py
index 7252e9808ee..4cf1193edbf 100644
--- a/src/transformers/quantizers/quantizer_spqr.py
+++ b/src/transformers/quantizers/quantizer_spqr.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
 
@@ -65,12 +65,17 @@ class SpQRHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
+
         replace_with_spqr_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=self.quantization_config.modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 8439e68a908..e233f0689aa 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import importlib
 import types
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from packaging import version
 
@@ -144,14 +144,12 @@ class TorchAoHfQuantizer(HfQuantizer):
             max_memory = {key: val * 0.9 for key, val in max_memory.items()}
         return max_memory
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert
-
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
-
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
         return
 
     def check_quantized_param(
diff --git a/src/transformers/quantizers/quantizer_vptq.py b/src/transformers/quantizers/quantizer_vptq.py
index 1672c3ebc5a..85483357448 100644
--- a/src/transformers/quantizers/quantizer_vptq.py
+++ b/src/transformers/quantizers/quantizer_vptq.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
 
@@ -68,6 +68,7 @@ class VptqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         """
@@ -76,14 +77,14 @@ class VptqHfQuantizer(HfQuantizer):
         """
         from ..integrations import replace_with_vptq_linear
 
-        modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + (
-            self.quantization_config.modules_to_not_convert or []
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
 
         replace_with_vptq_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index e8fe5b422c3..0988d8ac147 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -1424,8 +1424,6 @@ class HiggsConfig(QuantizationConfigMixin):
         tune_metadata: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
-        if modules_to_not_convert is None:
-            modules_to_not_convert = ["lm_head"]
         if tune_metadata is None:
             tune_metadata = {}
         self.quant_method = QuantizationMethod.HIGGS
@@ -1652,8 +1650,6 @@ class SpQRConfig(QuantizationConfigMixin):
         self.bits = bits
         self.beta1 = beta1
         self.beta2 = beta2
-        if modules_to_not_convert is None:
-            modules_to_not_convert = []
         self.modules_to_not_convert = modules_to_not_convert
         self.post_init()
 
@@ -1674,10 +1670,6 @@ class SpQRConfig(QuantizationConfigMixin):
             raise ValueError("SpQR currently only supports beta1 = 16")
         if self.beta2 != 16:
             raise ValueError("SpQR currently only supports beta2 = 16")
-
-        if self.modules_to_not_convert is not None and not isinstance(self.modules_to_not_convert, list):
-            raise ValueError("modules_to_not_convert must be a list of strings")
-
         if not isinstance(self.shapes, dict):
             raise TypeError("shapes must be a dict")
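
Taken together, these hunks replace the per-quantizer branching on `get_keys_to_not_convert` and the config skip lists with a single call to the new `HfQuantizer.get_modules_to_not_convert` static method. A minimal stand-alone sketch of the precedence that helper implements is shown below; the module names are purely illustrative, and the defensive copies are an addition for the example rather than part of the diff:

```python
from typing import List, Optional


def resolve_modules_to_not_convert(
    auto_detected: List[str],
    skip_modules: Optional[List[str]] = None,
    keep_in_fp32_modules: Optional[List[str]] = None,
) -> List[str]:
    # Mirrors the precedence in HfQuantizer.get_modules_to_not_convert:
    # an explicit skip list (e.g. quantization_config.modules_to_not_convert)
    # replaces the auto-detected keys entirely, and keep_in_fp32_modules is
    # always appended on top.
    modules = list(skip_modules) if skip_modules is not None else list(auto_detected)
    if keep_in_fp32_modules is not None:
        modules.extend(keep_in_fp32_modules)
    return modules


# Illustrative module names only.
print(resolve_modules_to_not_convert(["lm_head"], None, ["layernorm"]))     # ['lm_head', 'layernorm']
print(resolve_modules_to_not_convert(["lm_head"], ["embed_tokens"], None))  # ['embed_tokens']
```

Each quantizer now passes its own config field as the second argument (`modules_to_not_convert` for most backends, `llm_int8_skip_modules` for the bitsandbytes 4-bit and 8-bit quantizers), together with the `keep_in_fp32_modules` list forwarded through `_process_model_before_weight_loading`.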