Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[quants] refactor logic for modules_to_not_convert (#36672)
This commit is contained in:
parent bc3253f076
commit cc3a361b46
@@ -247,6 +247,25 @@ class HfQuantizer(ABC):
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
         )
 
+    @staticmethod
+    def get_modules_to_not_convert(
+        model: "PreTrainedModel",
+        skip_modules: Optional[List[str]] = None,
+        keep_in_fp32_modules: Optional[List[str]] = None,
+    ):
+        from ..integrations import get_keys_to_not_convert
+
+        modules_to_not_convert = []
+        if skip_modules is None:
+            modules_to_not_convert = get_keys_to_not_convert(model)
+        else:
+            modules_to_not_convert = skip_modules
+
+        if keep_in_fp32_modules is not None:
+            modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        return modules_to_not_convert
+
     @property
     def is_qat_trainable(self) -> bool:
         """Flag indicating whether the quantized model can carry out quantization aware training"""

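For reference, a minimal usage sketch of the new static helper. This is not part of the commit; the checkpoint and module names are illustrative, and it assumes a transformers install that includes this change:

from transformers import AutoModelForCausalLM
from transformers.quantizers.base import HfQuantizer

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # example tiny checkpoint

# With no explicit skip list, the helper falls back to get_keys_to_not_convert(model),
# which typically returns the output embedding (e.g. "lm_head").
print(HfQuantizer.get_modules_to_not_convert(model))

# An explicit skip list bypasses that heuristic; keep_in_fp32_modules is appended on top.
print(
    HfQuantizer.get_modules_to_not_convert(
        model, skip_modules=["lm_head"], keep_in_fp32_modules=["ln_f"]
    )
)
# -> ['lm_head', 'ln_f']
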
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib.metadata
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
 
 from packaging import version
 
@@ -96,13 +96,14 @@ class AwqQuantizer(HfQuantizer):
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        from ..integrations import replace_quantization_scales, replace_with_awq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, has_been_replaced = replace_with_awq_linear(
             model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert

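The remaining quantizer hunks below repeat this caller-side pattern. As a hedged sketch of the resulting shape for an HfQuantizer subclass (MyQuantizer and replace_with_my_linear are hypothetical placeholders, not names from this PR):

from typing import List, Optional

from transformers.quantizers.base import HfQuantizer


class MyQuantizer(HfQuantizer):
    """Hypothetical quantizer used only to illustrate the refactored pattern."""

    def _process_model_before_weight_loading(
        self, model, keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
    ):
        # One shared call replaces the per-quantizer get_keys_to_not_convert /
        # isinstance / extend boilerplate that this commit deletes.
        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
        )
        # ...then hand self.modules_to_not_convert to the module-replacement helper,
        # e.g. replace_with_my_linear(model, modules_to_not_convert=self.modules_to_not_convert).
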
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from .base import HfQuantizer
 
@@ -81,16 +81,14 @@ class BitNetHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bitnet_linear
+        from ..integrations import replace_with_bitnet_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_bitnet_linear(
             model,

@@ -288,23 +288,16 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:

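On the user side nothing changes for bitsandbytes: `llm_int8_skip_modules` is now simply forwarded to the shared helper as `skip_modules`. A hedged sketch (requires `bitsandbytes` and a GPU; the checkpoint is only an example):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # Forwarded to HfQuantizer.get_modules_to_not_convert() as skip_modules;
    # leaving it unset falls back to get_keys_to_not_convert(model).
    llm_int8_skip_modules=["lm_head"],
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=bnb_config, device_map="auto"
)
print(model.hf_quantizer.modules_to_not_convert)  # should include "lm_head"
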
@@ -245,23 +245,16 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:

@@ -155,16 +155,14 @@ class EetqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_eetq_linear
+        from ..integrations import replace_with_eetq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_eetq_linear(
             model,

@@ -161,16 +161,14 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear
+        from ..integrations import replace_with_fbgemm_fp8_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fbgemm_fp8_linear(
             model,

@@ -162,16 +162,14 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         device_map,
-        modules_to_not_convert: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         from ..integrations.finegrained_fp8 import replace_with_fp8_linear
 
-        self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
-
-        if self.quantization_config.modules_to_not_convert:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fp8_linear(
             model,

@@ -273,12 +273,8 @@ class HqqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = None,
         **kwargs,
     ):
-        keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else []
-
         # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
         # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
         model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)

@@ -177,20 +177,13 @@ class QuantoHfQuantizer(HfQuantizer):
         )
 
     def _process_model_before_weight_loading(
-        self, model: "PreTrainedModel", keep_in_fp32_modules: List[str] = [], **kwargs
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_quanto_layers
+        from ..integrations import replace_with_quanto_layers
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.modules_to_not_convert is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, _ = replace_with_quanto_layers(
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
 
@@ -65,12 +65,17 @@ class SpQRHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
+
         replace_with_spqr_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=self.quantization_config.modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config
 
@@ -13,7 +13,7 @@
 # limitations under the License.
 import importlib
 import types
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from packaging import version
 
@@ -144,14 +144,12 @@ class TorchAoHfQuantizer(HfQuantizer):
             max_memory = {key: val * 0.9 for key, val in max_memory.items()}
         return max_memory
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert
-
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
-
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
         return
 
     def check_quantized_param(

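The torchao path keeps its public config: an explicit modules_to_not_convert on TorchAoConfig is what reaches the helper above as skip_modules. A hedged sketch (requires `torchao`; checkpoint, dtype, and quantization settings are only examples):

from transformers import AutoModelForCausalLM, TorchAoConfig

quant_config = TorchAoConfig(
    "int4_weight_only",
    group_size=128,
    # Passed through to get_modules_to_not_convert() as skip_modules; if omitted,
    # the helper falls back to get_keys_to_not_convert(model).
    modules_to_not_convert=["lm_head"],
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype="bfloat16",
)
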
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
 
@@ -68,6 +68,7 @@ class VptqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         """
@@ -76,14 +77,14 @@ class VptqHfQuantizer(HfQuantizer):
         """
         from ..integrations import replace_with_vptq_linear
 
-        modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + (
-            self.quantization_config.modules_to_not_convert or []
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
 
         replace_with_vptq_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config
 
@@ -1424,8 +1424,6 @@ class HiggsConfig(QuantizationConfigMixin):
         tune_metadata: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
-        if modules_to_not_convert is None:
-            modules_to_not_convert = ["lm_head"]
         if tune_metadata is None:
             tune_metadata = {}
         self.quant_method = QuantizationMethod.HIGGS

@@ -1652,8 +1650,6 @@ class SpQRConfig(QuantizationConfigMixin):
         self.bits = bits
         self.beta1 = beta1
         self.beta2 = beta2
-        if modules_to_not_convert is None:
-            modules_to_not_convert = []
         self.modules_to_not_convert = modules_to_not_convert
         self.post_init()
 
@@ -1674,10 +1670,6 @@ class SpQRConfig(QuantizationConfigMixin):
             raise ValueError("SpQR currently only supports beta1 = 16")
         if self.beta2 != 16:
             raise ValueError("SpQR currently only supports beta2 = 16")
 
-        if self.modules_to_not_convert is not None and not isinstance(self.modules_to_not_convert, list):
-            raise ValueError("modules_to_not_convert must be a list of strings")
-
         if not isinstance(self.shapes, dict):
             raise TypeError("shapes must be a dict")
 
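With the two SpQRConfig hunks above, the config no longer injects a default (previously []) and no longer validates the list itself; resolution now happens in the quantizer via get_modules_to_not_convert. A small sketch of the config-side effect (importing from the concrete module and passing shapes explicitly are assumptions made only to keep the example self-contained):

from transformers.utils.quantization_config import SpQRConfig

config = SpQRConfig(shapes={})  # shapes given explicitly; not part of this change
print(config.modules_to_not_convert)  # None after this change (was [] before)

config = SpQRConfig(shapes={}, modules_to_not_convert=["lm_head"])
print(config.modules_to_not_convert)  # ['lm_head']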