diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index da571a7c5f8..1f1fff8bc39 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -767,12 +767,12 @@ if TYPE_CHECKING:
         EetqConfig,
         FbgemmFp8Config,
         FineGrainedFP8Config,
+        FPQuantConfig,
         GPTQConfig,
         HiggsConfig,
         HqqConfig,
         QuantoConfig,
         QuarkConfig,
-        FPQuantConfig,
         SpQRConfig,
         TorchAoConfig,
         VptqConfig,
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 35d80166986..822ec7af459 100644
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -27,6 +27,7 @@ from ..utils.quantization_config import (
     EetqConfig,
     FbgemmFp8Config,
     FineGrainedFP8Config,
+    FPQuantConfig,
     GPTQConfig,
     HiggsConfig,
     HqqConfig,
@@ -34,7 +35,6 @@ from ..utils.quantization_config import (
     QuantizationMethod,
     QuantoConfig,
     QuarkConfig,
-    FPQuantConfig,
     SpQRConfig,
     TorchAoConfig,
     VptqConfig,
@@ -50,12 +50,12 @@ from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
 from .quantizer_eetq import EetqHfQuantizer
 from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
 from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
+from .quantizer_fp_quant import FPQuantHfQuantizer
 from .quantizer_gptq import GptqHfQuantizer
 from .quantizer_higgs import HiggsHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
 from .quantizer_quanto import QuantoHfQuantizer
 from .quantizer_quark import QuarkHfQuantizer
-from .quantizer_fp_quant import FPQuantHfQuantizer
 from .quantizer_spqr import SpQRHfQuantizer
 from .quantizer_torchao import TorchAoHfQuantizer
 from .quantizer_vptq import VptqHfQuantizer
diff --git a/src/transformers/quantizers/quantizer_fp_quant.py b/src/transformers/quantizers/quantizer_fp_quant.py
index 236d15a94c5..2362a6aee44 100644
--- a/src/transformers/quantizers/quantizer_fp_quant.py
+++ b/src/transformers/quantizers/quantizer_fp_quant.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
-from ..utils.logging import tqdm
 from .base import HfQuantizer
 from .quantizers_utils import get_module_from_name
@@ -21,7 +20,7 @@ from .quantizers_utils import get_module_from_name
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_fp_quant_available, is_fp_quant_available, is_qutlass_available, is_torch_available, logging
+from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
 
 
@@ -126,7 +125,9 @@ class FPQuantHfQuantizer(HfQuantizer):
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         from fp_quant import FPQuantLinear
 
-        fp_quant_modules = {name: module for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
+        fp_quant_modules = {
+            name: module for name, module in model.named_modules() if isinstance(module, FPQuantLinear)
+        }
         for name, module in fp_quant_modules.items():
             if not self.quantization_config.store_master_weights and module.weight is not None:
                 module.weight = None
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 945a05cf27d..62f91894c28 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -155,6 +155,7 @@ from .import_utils import (
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
     is_flute_available,
+    is_fp_quant_available,
     is_fsdp_available,
     is_ftfy_available,
     is_g2p_en_available,
@@ -199,7 +200,6 @@ from .import_utils import (
     is_pytest_available,
     is_pytorch_quantization_available,
     is_quark_available,
-    is_fp_quant_available,
     is_qutlass_available,
     is_rich_available,
     is_rjieba_available,
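For reference, a minimal usage sketch of the FPQuant path wired up above (not part of the diff). It assumes `FPQuantConfig` accepts `store_master_weights` (the flag read in `_process_model_after_weight_loading`), that the `fp_quant` and `qutlass` packages plus a supported GPU are available, and that the checkpoint id below is only a placeholder.

```python
# Hypothetical end-to-end check of the FPQuant integration above.
# Assumptions: FPQuantConfig accepts `store_master_weights`, and the
# fp_quant / qutlass kernels are installed; the model id is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint

# store_master_weights=False lets the quantizer drop the original weights
# after loading, as done in _process_model_after_weight_loading.
quant_config = FPQuantConfig(store_master_weights=False)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

inputs = tokenizer("Quantization sanity check:", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```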