quartet -> fp_quant

Andrei Panferov 2025-07-01 10:29:15 +02:00
parent 7199b608f2
commit c2b5b29a8d
8 changed files with 55 additions and 64 deletions

View File

@@ -93,9 +93,9 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 [[autodoc]] QuarkConfig
-## QuartetConfig
+## FPQuantConfig
-[[autodoc]] QuartetConfig
+[[autodoc]] FPQuantConfig
 ## AutoRoundConfig

View File

@@ -272,7 +272,7 @@ _import_structure = {
         "HqqConfig",
         "QuantoConfig",
         "QuarkConfig",
-        "QuartetConfig",
+        "FPQuantConfig",
         "SpQRConfig",
         "TorchAoConfig",
         "VptqConfig",
@@ -772,7 +772,7 @@ if TYPE_CHECKING:
         HqqConfig,
         QuantoConfig,
         QuarkConfig,
-        QuartetConfig,
+        FPQuantConfig,
         SpQRConfig,
         TorchAoConfig,
         VptqConfig,

View File

@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"Quartet integration file"
+"FPQuant integration file"
 from ..utils import (
-    is_quartet_qat_available,
+    is_fp_quant_available,
     is_torch_available,
 )
@@ -23,27 +23,27 @@ if is_torch_available():
     pass
-if is_quartet_qat_available():
+if is_fp_quant_available():
-    from quartet_qat import QuartetConfig as QuartetLinearConfig
+    from fp_quant import FPQuantConfig as FPQuantLinearConfig
-    from quartet_qat import QuartetDtype
+    from fp_quant import FPQuantDtype
-from transformers.utils.quantization_config import QuartetConfig
+from transformers.utils.quantization_config import FPQuantConfig
-def adapt_quartet_config(config: QuartetConfig):
+def adapt_fp_quant_config(config: FPQuantConfig):
     if config.forward_dtype == "mxfp4":
-        forward_dtype = QuartetDtype.MXFP4
+        forward_dtype = FPQuantDtype.MXFP4
     else:
         raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}")
     if config.backward_dtype == "mxfp4":
-        backward_dtype = QuartetDtype.MXFP4
+        backward_dtype = FPQuantDtype.MXFP4
     elif config.backward_dtype == "bf16":
-        backward_dtype = QuartetDtype.BF16
+        backward_dtype = FPQuantDtype.BF16
     else:
         raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")
-    return QuartetLinearConfig(
+    return FPQuantLinearConfig(
         forward_dtype=forward_dtype,
         forward_method=config.forward_method,
         backward_dtype=backward_dtype,
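
For orientation, a minimal sketch of how the adapter above would be called, assuming the FPQuantConfig constructor accepts the forward_dtype/backward_dtype keywords described in its docstring later in this commit (not an official example):

    # Hedged usage sketch for adapt_fp_quant_config; field values are illustrative.
    from transformers.integrations.fp_quant import adapt_fp_quant_config
    from transformers.utils.quantization_config import FPQuantConfig

    hf_config = FPQuantConfig(forward_dtype="mxfp4", backward_dtype="bf16")
    linear_config = adapt_fp_quant_config(hf_config)
    # linear_config is the fp_quant package's config (aliased FPQuantLinearConfig above),
    # carrying forward_dtype=FPQuantDtype.MXFP4 and backward_dtype=FPQuantDtype.BF16.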

View File

@@ -34,7 +34,7 @@ from ..utils.quantization_config import (
     QuantizationMethod,
     QuantoConfig,
     QuarkConfig,
-    QuartetConfig,
+    FPQuantConfig,
     SpQRConfig,
     TorchAoConfig,
     VptqConfig,
@@ -55,7 +55,7 @@ from .quantizer_higgs import HiggsHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
 from .quantizer_quanto import QuantoHfQuantizer
 from .quantizer_quark import QuarkHfQuantizer
-from .quantizer_quartet import QuartetHfQuantizer
+from .quantizer_fp_quant import FPQuantHfQuantizer
 from .quantizer_spqr import SpQRHfQuantizer
 from .quantizer_torchao import TorchAoHfQuantizer
 from .quantizer_vptq import VptqHfQuantizer
@@ -69,7 +69,7 @@ AUTO_QUANTIZER_MAPPING = {
     "aqlm": AqlmHfQuantizer,
     "quanto": QuantoHfQuantizer,
     "quark": QuarkHfQuantizer,
-    "quartet": QuartetHfQuantizer,
+    "fp_quant": FPQuantHfQuantizer,
     "eetq": EetqHfQuantizer,
     "higgs": HiggsHfQuantizer,
     "hqq": HqqHfQuantizer,
@@ -92,7 +92,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "quark": QuarkConfig,
-    "quartet": QuartetConfig,
+    "fp_quant": FPQuantConfig,
     "hqq": HqqConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "fbgemm_fp8": FbgemmFp8Config,

View File

@@ -21,7 +21,7 @@ from .quantizers_utils import get_module_from_name
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
-from ..utils import is_quartet_available, is_quartet_qat_available, is_qutlass_available, is_torch_available, logging
+from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
@@ -31,7 +31,7 @@ if is_torch_available():
 logger = logging.get_logger(__name__)
-class QuartetHfQuantizer(HfQuantizer):
+class FPQuantHfQuantizer(HfQuantizer):
     """
     Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
     """
@@ -39,7 +39,7 @@ class QuartetHfQuantizer(HfQuantizer):
     requires_calibration = False
     requires_parameters_quantization = True
     is_qat_trainable = True
-    required_packages = ["qutlass", "quartet_qat", "quartet"]
+    required_packages = ["qutlass", "fp_quant"]
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -48,36 +48,33 @@ class QuartetHfQuantizer(HfQuantizer):
     def validate_environment(self, device_map, **kwargs):
         if not torch.cuda.is_available():
             raise NotImplementedError(
-                "Quartet quantization is only supported on GPU. Please use a different quantizer."
+                "FPQuant quantization is only supported on GPU. Please use a different quantizer."
             )
         if not is_qutlass_available():
-            raise ImportError("Using `quartet` quantization requires qutlass: `pip install qutlass`")
+            raise ImportError("Using `fp_quant` quantization requires qutlass: `pip install qutlass`")
-        if not is_quartet_qat_available():
+        if not is_fp_quant_available():
-            raise ImportError("Using `quartet` quantization requires quartet_qat: `pip install quartet_qat`")
+            raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")
-        if not is_quartet_available():
-            raise ImportError("Using `quartet` quantization requires quartet: `pip install quartet`")
         if device_map is None:
             raise ValueError(
-                "You are attempting to load a Quartet model without setting device_map."
+                "You are attempting to load a FPQuant model without setting device_map."
                 " Please set device_map comprised of 'cuda' devices."
             )
         elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
             raise ValueError(
-                "You are attempting to load a Quartet model with a device_map that contains a CPU or disk device."
+                "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
                 " This is not supported. Please remove the CPU or disk device from the device_map."
            )
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
-            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
+            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.")
-            torch_dtype = torch.float16
+            torch_dtype = torch.bfloat16
-        elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
+        elif torch_dtype != torch.bfloat16:
             raise ValueError(
-                f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
+                f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`."
             )
         return torch_dtype
@@ -91,10 +88,10 @@ class QuartetHfQuantizer(HfQuantizer):
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
     ):
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
         module, _ = get_module_from_name(model, param_name)
-        assert isinstance(module, QuartetLinear), f"Module {param_name} is not a QuartetLinear somehow..."
+        assert isinstance(module, FPQuantLinear), f"Module {param_name} is not a FPQuantLinear somehow..."
         if param_name.endswith(".qweight"):
             module.qweight = torch.nn.Parameter(
@@ -116,34 +113,34 @@ class QuartetHfQuantizer(HfQuantizer):
         model: "PreTrainedModel",
         **kwargs,
     ):
-        from quartet_qat import replace_with_quartet_linear
+        from fp_quant import replace_with_fp_quant_linear
-        from ..integrations.quartet import adapt_quartet_config
+        from ..integrations.fp_quant import adapt_fp_quant_config
-        replace_with_quartet_linear(
+        replace_with_fp_quant_linear(
             model,
-            quartet_linear_config=adapt_quartet_config(self.quantization_config),
+            fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
         )
         model.config.quantization_config = self.quantization_config
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
-        quartet_modules = {name: module for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
+        fp_quant_modules = {name: module for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
-        for name, module in tqdm(quartet_modules.items(), desc="Pre-processing Quartet modules", leave=False):
+        for name, module in fp_quant_modules.items():
             if not self.quantization_config.store_master_weights and module.weight is not None:
                 module.weight = None
     def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
-        quartet_names = {name for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
+        fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
         def should_exclude(key: str) -> bool:
             if key.endswith(".weight") or key.endswith(".bias"):
                 return False
             full_key = f"{prefix}.{key}"
-            return any(name in key or name in full_key for name in quartet_names)
+            return any(name in key or name in full_key for name in fp_quant_names)
         return [key for key in missing_keys if not should_exclude(key)]
@@ -162,11 +159,11 @@ class QuartetHfQuantizer(HfQuantizer):
         state_dict: Dict[str, Any],
         **kwargs,
     ) -> bool:
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module, QuartetLinear) and tensor_name in ["weight", "qweight"]:
+        if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight"]:
-            # Only quantize weights of QuartetLinear modules that are not already quantized
+            # Only quantize weights of FPQuantLinear modules that are not already quantized
             return True
         else:
             return False

View File

@@ -199,8 +199,7 @@ from .import_utils import (
     is_pytest_available,
     is_pytorch_quantization_available,
     is_quark_available,
-    is_quartet_available,
+    is_fp_quant_available,
-    is_quartet_qat_available,
     is_qutlass_available,
     is_rich_available,
     is_rjieba_available,

View File

@@ -170,8 +170,7 @@ _auto_round_available, _auto_round_version = _is_package_available("auto_round",
 # `importlib.metadata.version` doesn't work with `awq`
 _auto_awq_available = importlib.util.find_spec("awq") is not None
 _quark_available = _is_package_available("quark")
-_quartet_available = _is_package_available("quartet")
+_fp_quant_available = _is_package_available("fp_quant")
-_quartet_qat_available = _is_package_available("quartet_qat")
 _qutlass_available = _is_package_available("qutlass")
 _is_optimum_quanto_available = False
 try:
@@ -1152,12 +1151,8 @@ def is_quark_available():
     return _quark_available
-def is_quartet_available():
+def is_fp_quant_available():
-    return _quartet_available
+    return _fp_quant_available
-def is_quartet_qat_available():
-    return _quartet_qat_available
 def is_qutlass_available():

View File

@@ -63,7 +63,7 @@ class QuantizationMethod(str, Enum):
     SPQR = "spqr"
     FP8 = "fp8"
     QUARK = "quark"
-    QUARTET = "quartet"
+    FPQUANT = "fp_quant"
     AUTOROUND = "auto-round"
@@ -1552,9 +1552,9 @@ class HiggsConfig(QuantizationConfigMixin):
 @dataclass
-class QuartetConfig(QuantizationConfigMixin):
+class FPQuantConfig(QuantizationConfigMixin):
     """
-    QuartetConfig is a configuration class for quantization using the Quartet method.
+    FPQuantConfig is a configuration class for quantization using the FPQuant method.
     Args:
         forward_dtype (`str`, *optional*, defaults to `"mxfp4"`):
@@ -1587,7 +1587,7 @@ class QuartetConfig(QuantizationConfigMixin):
         self.hadamard_group_size = hadamard_group_size
         self.modules_to_not_convert = modules_to_not_convert
-        self.quant_method = QuantizationMethod.QUARTET
+        self.quant_method = QuantizationMethod.FPQUANT
         self.post_init()
     def post_init(self):
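
Taken together, the renamed pieces would be exercised roughly as follows; a hedged end-to-end sketch assuming qutlass and fp_quant are installed and a CUDA device is available (the model id is only a placeholder):

    # End-to-end sketch, not an official example.
    import torch
    from transformers import AutoModelForCausalLM, FPQuantConfig  # exported via __init__ above

    quant_config = FPQuantConfig(forward_dtype="mxfp4", backward_dtype="bf16")
    model = AutoModelForCausalLM.from_pretrained(
        "org/model-id",                   # placeholder; any causal LM checkpoint
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,       # update_torch_dtype only accepts bfloat16
        device_map="cuda",                # validate_environment rejects missing/CPU/disk device maps
    )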