diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
index 1e54af25652..992f629e5a1 100755
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -93,9 +93,9 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 [[autodoc]] QuarkConfig
 
-## QuartetConfig
+## FPQuantConfig
 
-[[autodoc]] QuartetConfig
+[[autodoc]] FPQuantConfig
 
 ## AutoRoundConfig
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 0b6cdafe126..da571a7c5f8 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -272,7 +272,7 @@ _import_structure = {
         "HqqConfig",
         "QuantoConfig",
         "QuarkConfig",
-        "QuartetConfig",
+        "FPQuantConfig",
         "SpQRConfig",
         "TorchAoConfig",
         "VptqConfig",
@@ -772,7 +772,7 @@ if TYPE_CHECKING:
         HqqConfig,
         QuantoConfig,
         QuarkConfig,
-        QuartetConfig,
+        FPQuantConfig,
         SpQRConfig,
         TorchAoConfig,
         VptqConfig,
diff --git a/src/transformers/integrations/quartet.py b/src/transformers/integrations/fp_quant.py
similarity index 73%
rename from src/transformers/integrations/quartet.py
rename to src/transformers/integrations/fp_quant.py
index 016f74e7e8f..1c4be8755d0 100644
--- a/src/transformers/integrations/quartet.py
+++ b/src/transformers/integrations/fp_quant.py
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"Quartet integration file"
+"FPQuant integration file"
 
 from ..utils import (
-    is_quartet_qat_available,
+    is_fp_quant_available,
     is_torch_available,
 )
 
@@ -23,27 +23,27 @@
 if is_torch_available():
     pass
 
-if is_quartet_qat_available():
-    from quartet_qat import QuartetConfig as QuartetLinearConfig
-    from quartet_qat import QuartetDtype
+if is_fp_quant_available():
+    from fp_quant import FPQuantConfig as FPQuantLinearConfig
+    from fp_quant import FPQuantDtype
 
-from transformers.utils.quantization_config import QuartetConfig
+from transformers.utils.quantization_config import FPQuantConfig
 
 
-def adapt_quartet_config(config: QuartetConfig):
+def adapt_fp_quant_config(config: FPQuantConfig):
     if config.forward_dtype == "mxfp4":
-        forward_dtype = QuartetDtype.MXFP4
+        forward_dtype = FPQuantDtype.MXFP4
     else:
         raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}")
 
     if config.backward_dtype == "mxfp4":
-        backward_dtype = QuartetDtype.MXFP4
+        backward_dtype = FPQuantDtype.MXFP4
     elif config.backward_dtype == "bf16":
-        backward_dtype = QuartetDtype.BF16
+        backward_dtype = FPQuantDtype.BF16
     else:
         raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")
 
-    return QuartetLinearConfig(
+    return FPQuantLinearConfig(
         forward_dtype=forward_dtype,
         forward_method=config.forward_method,
         backward_dtype=backward_dtype,
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 91ae52dd942..35d80166986 100644
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -34,7 +34,7 @@ from ..utils.quantization_config import (
     QuantizationMethod,
     QuantoConfig,
     QuarkConfig,
-    QuartetConfig,
+    FPQuantConfig,
     SpQRConfig,
     TorchAoConfig,
     VptqConfig,
@@ -55,7 +55,7 @@ from .quantizer_higgs import HiggsHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
 from .quantizer_quanto import QuantoHfQuantizer
 from .quantizer_quark import QuarkHfQuantizer
-from .quantizer_quartet import QuartetHfQuantizer
+from .quantizer_fp_quant import FPQuantHfQuantizer
 from .quantizer_spqr import SpQRHfQuantizer
 from .quantizer_torchao import TorchAoHfQuantizer
 from .quantizer_vptq import VptqHfQuantizer
@@ -69,7 +69,7 @@ AUTO_QUANTIZER_MAPPING = {
     "aqlm": AqlmHfQuantizer,
     "quanto": QuantoHfQuantizer,
     "quark": QuarkHfQuantizer,
-    "quartet": QuartetHfQuantizer,
+    "fp_quant": FPQuantHfQuantizer,
     "eetq": EetqHfQuantizer,
     "higgs": HiggsHfQuantizer,
     "hqq": HqqHfQuantizer,
@@ -92,7 +92,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "quark": QuarkConfig,
-    "quartet": QuartetConfig,
+    "fp_quant": FPQuantConfig,
     "hqq": HqqConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "fbgemm_fp8": FbgemmFp8Config,
diff --git a/src/transformers/quantizers/quantizer_quartet.py b/src/transformers/quantizers/quantizer_fp_quant.py
similarity index 68%
rename from src/transformers/quantizers/quantizer_quartet.py
rename to src/transformers/quantizers/quantizer_fp_quant.py
index bce2f28bd27..236d15a94c5 100644
--- a/src/transformers/quantizers/quantizer_quartet.py
+++ b/src/transformers/quantizers/quantizer_fp_quant.py
@@ -21,7 +21,7 @@ from .quantizers_utils import get_module_from_name
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_quartet_available, is_quartet_qat_available, is_qutlass_available, is_torch_available, logging
+from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
 
 
@@ -31,7 +31,7 @@ if is_torch_available():
 logger = logging.get_logger(__name__)
 
 
-class QuartetHfQuantizer(HfQuantizer):
+class FPQuantHfQuantizer(HfQuantizer):
     """
     Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
     """
@@ -39,7 +39,7 @@ class QuartetHfQuantizer(HfQuantizer):
     requires_calibration = False
     requires_parameters_quantization = True
     is_qat_trainable = True
-    required_packages = ["qutlass", "quartet_qat", "quartet"]
+    required_packages = ["qutlass", "fp_quant"]
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -48,36 +48,33 @@
     def validate_environment(self, device_map, **kwargs):
         if not torch.cuda.is_available():
             raise NotImplementedError(
-                "Quartet quantization is only supported on GPU. Please use a different quantizer."
+                "FPQuant quantization is only supported on GPU. Please use a different quantizer."
             )
 
         if not is_qutlass_available():
-            raise ImportError("Using `quartet` quantization requires qutlass: `pip install qutlass`")
+            raise ImportError("Using `fp_quant` quantization requires qutlass: `pip install qutlass`")
 
-        if not is_quartet_qat_available():
-            raise ImportError("Using `quartet` quantization requires quartet_qat: `pip install quartet_qat`")
-
-        if not is_quartet_available():
-            raise ImportError("Using `quartet` quantization requires quartet: `pip install quartet`")
+        if not is_fp_quant_available():
+            raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")
 
         if device_map is None:
             raise ValueError(
-                "You are attempting to load a Quartet model without setting device_map."
+                "You are attempting to load an FPQuant model without setting device_map."
                 " Please set device_map comprised of 'cuda' devices."
             )
         elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
             raise ValueError(
-                "You are attempting to load a Quartet model with a device_map that contains a CPU or disk device."
+                "You are attempting to load an FPQuant model with a device_map that contains a CPU or disk device."
                 " This is not supported. Please remove the CPU or disk device from the device_map."
             )
 
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
-            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
-            torch_dtype = torch.float16
-        elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
+            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.")
+            torch_dtype = torch.bfloat16
+        elif torch_dtype != torch.bfloat16:
             raise ValueError(
-                f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
+                f"Invalid `torch_dtype` {torch_dtype}. FPQuant quantization only supports `torch_dtype=torch.bfloat16`."
             )
         return torch_dtype
 
@@ -91,10 +88,10 @@ class QuartetHfQuantizer(HfQuantizer):
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
     ):
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
         module, _ = get_module_from_name(model, param_name)
-        assert isinstance(module, QuartetLinear), f"Module {param_name} is not a QuartetLinear somehow..."
+        assert isinstance(module, FPQuantLinear), f"Module {param_name} is not an FPQuantLinear somehow..."
 
         if param_name.endswith(".qweight"):
             module.qweight = torch.nn.Parameter(
@@ -116,34 +113,34 @@
         model: "PreTrainedModel",
         **kwargs,
     ):
-        from quartet_qat import replace_with_quartet_linear
+        from fp_quant import replace_with_fp_quant_linear
 
-        from ..integrations.quartet import adapt_quartet_config
+        from ..integrations.fp_quant import adapt_fp_quant_config
 
-        replace_with_quartet_linear(
+        replace_with_fp_quant_linear(
             model,
-            quartet_linear_config=adapt_quartet_config(self.quantization_config),
+            fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
         )
 
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
-        quartet_modules = {name: module for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
-        for name, module in tqdm(quartet_modules.items(), desc="Pre-processing Quartet modules", leave=False):
+        fp_quant_modules = {name: module for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
+        for name, module in fp_quant_modules.items():
             if not self.quantization_config.store_master_weights and module.weight is not None:
                 module.weight = None
 
     def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
-        quartet_names = {name for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
+        fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
 
         def should_exclude(key: str) -> bool:
             if key.endswith(".weight") or key.endswith(".bias"):
                 return False
             full_key = f"{prefix}.{key}"
-            return any(name in key or name in full_key for name in quartet_names)
+            return any(name in key or name in full_key for name in fp_quant_names)
 
         return [key for key in missing_keys if not should_exclude(key)]
 
@@ -162,11 +159,11 @@
         state_dict: Dict[str, Any],
         **kwargs,
     ) -> bool:
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module, QuartetLinear) and tensor_name in ["weight", "qweight"]:
-            # Only quantize weights of QuartetLinear modules that are not already quantized
+        if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight"]:
+            # Only quantize weights of FPQuantLinear modules that are not already quantized
             return True
         else:
             return False
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 670e533353c..945a05cf27d 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -199,8 +199,7 @@ from .import_utils import (
     is_pytest_available,
     is_pytorch_quantization_available,
     is_quark_available,
-    is_quartet_available,
-    is_quartet_qat_available,
+    is_fp_quant_available,
     is_qutlass_available,
     is_rich_available,
     is_rjieba_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 7d031e2c06e..b58dbff7e2b 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -170,8 +170,7 @@ _auto_round_available, _auto_round_version = _is_package_available("auto_round",
 # `importlib.metadata.version` doesn't work with `awq`
 _auto_awq_available = importlib.util.find_spec("awq") is not None
 _quark_available = _is_package_available("quark")
-_quartet_available = _is_package_available("quartet")
-_quartet_qat_available = _is_package_available("quartet_qat")
+_fp_quant_available = _is_package_available("fp_quant")
 _qutlass_available = _is_package_available("qutlass")
 _is_optimum_quanto_available = False
 try:
@@ -1152,12 +1151,8 @@ def is_quark_available():
     return _quark_available
 
 
-def is_quartet_available():
-    return _quartet_available
-
-
-def is_quartet_qat_available():
-    return _quartet_qat_available
+def is_fp_quant_available():
+    return _fp_quant_available
 
 
 def is_qutlass_available():
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 260cf6d115a..52a34c581b0 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -63,7 +63,7 @@ class QuantizationMethod(str, Enum):
     SPQR = "spqr"
     FP8 = "fp8"
     QUARK = "quark"
-    QUARTET = "quartet"
+    FPQUANT = "fp_quant"
     AUTOROUND = "auto-round"
 
 
@@ -1552,9 +1552,9 @@ class HiggsConfig(QuantizationConfigMixin):
 
 
 @dataclass
-class QuartetConfig(QuantizationConfigMixin):
+class FPQuantConfig(QuantizationConfigMixin):
     """
-    QuartetConfig is a configuration class for quantization using the Quartet method.
+    FPQuantConfig is a configuration class for quantization using the FPQuant method.
 
     Args:
         forward_dtype (`str`, *optional*, defaults to `"mxfp4"`):
@@ -1587,7 +1587,7 @@
         self.hadamard_group_size = hadamard_group_size
         self.modules_to_not_convert = modules_to_not_convert
 
-        self.quant_method = QuantizationMethod.QUARTET
+        self.quant_method = QuantizationMethod.FPQUANT
         self.post_init()
 
     def post_init(self):
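
As a usage reference (not part of the diff), here is a minimal sketch of loading a model with the renamed config. It assumes the `qutlass` and `fp_quant` packages are installed and a CUDA GPU is available; the checkpoint id is a placeholder.

```python
import torch

from transformers import AutoModelForCausalLM, FPQuantConfig

# FPQuantConfig replaces the old QuartetConfig. Per adapt_fp_quant_config above,
# forward_dtype currently only accepts "mxfp4"; backward_dtype accepts "mxfp4" or "bf16".
quantization_config = FPQuantConfig(forward_dtype="mxfp4", backward_dtype="bf16")

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder model id
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,  # FPQuantHfQuantizer only allows bfloat16
    device_map="cuda",  # validate_environment requires a CUDA-only device_map
)
```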