[quants] refactor logic for modules_to_not_convert (#36672)

Marc Sun 2025-03-12 23:43:30 +01:00 committed by GitHub
parent bc3253f076
commit cc3a361b46
14 changed files with 81 additions and 98 deletions

View File

@@ -247,6 +247,25 @@ class HfQuantizer(ABC):
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
         )
 
+    @staticmethod
+    def get_modules_to_not_convert(
+        model: "PreTrainedModel",
+        skip_modules: Optional[List[str]] = None,
+        keep_in_fp32_modules: Optional[List[str]] = None,
+    ):
+        from ..integrations import get_keys_to_not_convert
+
+        modules_to_not_convert = []
+        if skip_modules is None:
+            modules_to_not_convert = get_keys_to_not_convert(model)
+        else:
+            modules_to_not_convert = skip_modules
+
+        if keep_in_fp32_modules is not None:
+            modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        return modules_to_not_convert
+
     @property
     def is_qat_trainable(self) -> bool:
         """Flag indicating whether the quantized model can carry out quantization aware training"""

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib.metadata
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
 
 from packaging import version
@@ -96,13 +96,14 @@ class AwqQuantizer(HfQuantizer):
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        from ..integrations import replace_quantization_scales, replace_with_awq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, has_been_replaced = replace_with_awq_linear(
             model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
View File

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from .base import HfQuantizer
@@ -81,16 +81,14 @@ class BitNetHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bitnet_linear
+        from ..integrations import replace_with_bitnet_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_bitnet_linear(
             model,
View File

@@ -288,23 +288,16 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:
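
For comparison, a hedged sketch of the duplicated branching this consolidation removes from the bitsandbytes quantizers; both functions below are illustrative rewrites, not code from the repository (here `skip_modules` stands for `llm_int8_skip_modules`):

def old_style(model, skip_modules, keep_in_fp32_modules=None):
    # Inline logic previously copied in each quantizer (the old default for
    # keep_in_fp32_modules was an empty list).
    from transformers.integrations import get_keys_to_not_convert

    modules = get_keys_to_not_convert(model) if skip_modules is None else skip_modules
    if not isinstance(modules, list):
        modules = [modules]
    modules.extend(keep_in_fp32_modules or [])
    return modules


def new_style(model, skip_modules, keep_in_fp32_modules=None):
    # Single call into the shared base-class helper added in this commit.
    from transformers.quantizers.base import HfQuantizer

    return HfQuantizer.get_modules_to_not_convert(model, skip_modules, keep_in_fp32_modules)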

View File

@@ -245,23 +245,16 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:

View File

@@ -155,16 +155,14 @@ class EetqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_eetq_linear
+        from ..integrations import replace_with_eetq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_eetq_linear(
             model,

View File

@@ -161,16 +161,14 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
    ):
-        from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear
+        from ..integrations import replace_with_fbgemm_fp8_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fbgemm_fp8_linear(
             model,

View File

@@ -162,16 +162,14 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        modules_to_not_convert: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         from ..integrations.finegrained_fp8 import replace_with_fp8_linear
 
-        self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
-
-        if self.quantization_config.modules_to_not_convert:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fp8_linear(
             model,

View File

@@ -273,12 +273,8 @@ class HqqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = None,
         **kwargs,
     ):
-        keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else []
-
         # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
         # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
         model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)

View File

@@ -177,20 +177,13 @@ class QuantoHfQuantizer(HfQuantizer):
         )
 
     def _process_model_before_weight_loading(
-        self, model: "PreTrainedModel", keep_in_fp32_modules: List[str] = [], **kwargs
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_quanto_layers
+        from ..integrations import replace_with_quanto_layers
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.modules_to_not_convert is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, _ = replace_with_quanto_layers(
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config

View File

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
@@ -65,12 +65,17 @@ class SpQRHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
+
         replace_with_spqr_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=self.quantization_config.modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 import importlib
 import types
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from packaging import version
@@ -144,14 +144,12 @@ class TorchAoHfQuantizer(HfQuantizer):
             max_memory = {key: val * 0.9 for key, val in max_memory.items()}
         return max_memory
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert
-
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
         return
 
     def check_quantized_param(

View File

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
@@ -68,6 +68,7 @@ class VptqHfQuantizer(HfQuantizer):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         """
@@ -76,14 +77,14 @@ class VptqHfQuantizer(HfQuantizer):
         """
         from ..integrations import replace_with_vptq_linear
 
-        modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + (
-            self.quantization_config.modules_to_not_convert or []
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
 
         replace_with_vptq_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config
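
The VPTQ change is a small behavioral shift as well as a cleanup: the skip list no longer comes from an ad-hoc `modules_to_not_convert` kwarg merged with the config value, but from the shared helper, which can also fall back to model-derived keys and fold in `keep_in_fp32_modules`. A hedged before/after sketch (function names are illustrative, not code from the repository):

from transformers.quantizers.base import HfQuantizer


def vptq_skip_list_old(kwargs, quantization_config):
    # Old behavior: caller-supplied kwarg plus the config value, nothing else.
    return kwargs.get("modules_to_not_convert", []) + (
        quantization_config.modules_to_not_convert or []
    )


def vptq_skip_list_new(model, quantization_config, keep_in_fp32_modules=None):
    # New behavior: falls back to get_keys_to_not_convert(model) when the config
    # value is None and appends keep_in_fp32_modules when provided.
    return HfQuantizer.get_modules_to_not_convert(
        model, quantization_config.modules_to_not_convert, keep_in_fp32_modules
    )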

View File

@@ -1424,8 +1424,6 @@ class HiggsConfig(QuantizationConfigMixin):
         tune_metadata: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
-        if modules_to_not_convert is None:
-            modules_to_not_convert = ["lm_head"]
         if tune_metadata is None:
             tune_metadata = {}
         self.quant_method = QuantizationMethod.HIGGS
@@ -1652,8 +1650,6 @@ class SpQRConfig(QuantizationConfigMixin):
         self.bits = bits
         self.beta1 = beta1
         self.beta2 = beta2
-        if modules_to_not_convert is None:
-            modules_to_not_convert = []
         self.modules_to_not_convert = modules_to_not_convert
         self.post_init()
@@ -1674,10 +1670,6 @@ class SpQRConfig(QuantizationConfigMixin):
             raise ValueError("SpQR currently only supports beta1 = 16")
         if self.beta2 != 16:
             raise ValueError("SpQR currently only supports beta2 = 16")
-        if self.modules_to_not_convert is not None and not isinstance(self.modules_to_not_convert, list):
-            raise ValueError("modules_to_not_convert must be a list of strings")
         if not isinstance(self.shapes, dict):
             raise TypeError("shapes must be a dict")