quartet -> fp_quant

Andrei Panferov 2025-07-01 10:29:15 +02:00
parent 7199b608f2
commit c2b5b29a8d
8 changed files with 55 additions and 64 deletions

View File

@@ -93,9 +93,9 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] QuarkConfig
## QuartetConfig
## FPQuantConfig
[[autodoc]] QuartetConfig
[[autodoc]] FPQuantConfig
## AutoRoundConfig

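For reference, a hedged usage sketch of the FPQuantConfig documented above. The checkpoint id is a placeholder, the `fp_quant` and `qutlass` packages are assumed to be installed, and the device/dtype choices mirror the constraints enforced by the quantizer later in this commit.

```python
# Hedged sketch, not part of the commit: load a model with the renamed config.
import torch
from transformers import AutoModelForCausalLM, FPQuantConfig

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",                # placeholder checkpoint id
    quantization_config=FPQuantConfig(),  # forward_dtype defaults to "mxfp4"
    torch_dtype=torch.bfloat16,           # FPQuantHfQuantizer requires bfloat16
    device_map="cuda",                    # and a CUDA-only device_map
)
```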
View File

@@ -272,7 +272,7 @@ _import_structure = {
"HqqConfig",
"QuantoConfig",
"QuarkConfig",
"QuartetConfig",
"FPQuantConfig",
"SpQRConfig",
"TorchAoConfig",
"VptqConfig",
@@ -772,7 +772,7 @@ if TYPE_CHECKING:
HqqConfig,
QuantoConfig,
QuarkConfig,
QuartetConfig,
FPQuantConfig,
SpQRConfig,
TorchAoConfig,
VptqConfig,

View File

@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"Quartet integration file"
"FPQuant integration file"
from ..utils import (
is_quartet_qat_available,
is_fp_quant_available,
is_torch_available,
)
@@ -23,27 +23,27 @@ if is_torch_available():
pass
if is_quartet_qat_available():
from quartet_qat import QuartetConfig as QuartetLinearConfig
from quartet_qat import QuartetDtype
if is_fp_quant_available():
from fp_quant import FPQuantConfig as FPQuantLinearConfig
from fp_quant import FPQuantDtype
from transformers.utils.quantization_config import QuartetConfig
from transformers.utils.quantization_config import FPQuantConfig
def adapt_quartet_config(config: QuartetConfig):
def adapt_fp_quant_config(config: FPQuantConfig):
if config.forward_dtype == "mxfp4":
forward_dtype = QuartetDtype.MXFP4
forward_dtype = FPQuantDtype.MXFP4
else:
raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}")
if config.backward_dtype == "mxfp4":
backward_dtype = QuartetDtype.MXFP4
backward_dtype = FPQuantDtype.MXFP4
elif config.backward_dtype == "bf16":
backward_dtype = QuartetDtype.BF16
backward_dtype = FPQuantDtype.BF16
else:
raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")
return QuartetLinearConfig(
return FPQuantLinearConfig(
forward_dtype=forward_dtype,
forward_method=config.forward_method,
backward_dtype=backward_dtype,

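A hedged sketch of what `adapt_fp_quant_config` does, assuming the `FPQuantConfig` constructor accepts `forward_dtype`/`backward_dtype` keywords (the attributes read above) and that the `fp_quant` package is installed.

```python
# Sketch: translate the transformers-side config into fp_quant's linear config.
# Only "mxfp4" forward and "mxfp4"/"bf16" backward dtypes are accepted above;
# anything else raises ValueError.
from transformers.integrations.fp_quant import adapt_fp_quant_config
from transformers.utils.quantization_config import FPQuantConfig

hf_config = FPQuantConfig(forward_dtype="mxfp4", backward_dtype="bf16")
linear_config = adapt_fp_quant_config(hf_config)
# linear_config is an fp_quant FPQuantConfig (aliased FPQuantLinearConfig above)
# carrying FPQuantDtype.MXFP4 / FPQuantDtype.BF16.
```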
View File

@@ -34,7 +34,7 @@ from ..utils.quantization_config import (
QuantizationMethod,
QuantoConfig,
QuarkConfig,
QuartetConfig,
FPQuantConfig,
SpQRConfig,
TorchAoConfig,
VptqConfig,
@@ -55,7 +55,7 @@ from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_quark import QuarkHfQuantizer
from .quantizer_quartet import QuartetHfQuantizer
from .quantizer_fp_quant import FPQuantHfQuantizer
from .quantizer_spqr import SpQRHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer
@@ -69,7 +69,7 @@ AUTO_QUANTIZER_MAPPING = {
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"quark": QuarkHfQuantizer,
"quartet": QuartetHfQuantizer,
"fp_quant": FPQuantHfQuantizer,
"eetq": EetqHfQuantizer,
"higgs": HiggsHfQuantizer,
"hqq": HqqHfQuantizer,
@@ -92,7 +92,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"quark": QuarkConfig,
"quartet": QuartetConfig,
"fp_quant": FPQuantConfig,
"hqq": HqqConfig,
"compressed-tensors": CompressedTensorsConfig,
"fbgemm_fp8": FbgemmFp8Config,

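The two `"fp_quant"` entries registered above are what route a serialized `quant_method` to the renamed classes; a quick sanity-check sketch:

```python
# Sanity check: the "fp_quant" key resolves to the renamed config and quantizer.
from transformers.quantizers.auto import (
    AUTO_QUANTIZATION_CONFIG_MAPPING,
    AUTO_QUANTIZER_MAPPING,
)

assert AUTO_QUANTIZATION_CONFIG_MAPPING["fp_quant"].__name__ == "FPQuantConfig"
assert AUTO_QUANTIZER_MAPPING["fp_quant"].__name__ == "FPQuantHfQuantizer"
```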
View File

@@ -21,7 +21,7 @@ from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_quartet_available, is_quartet_qat_available, is_qutlass_available, is_torch_available, logging
from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
@@ -31,7 +31,7 @@ if is_torch_available():
logger = logging.get_logger(__name__)
class QuartetHfQuantizer(HfQuantizer):
class FPQuantHfQuantizer(HfQuantizer):
"""
Quantizer of the FPQuant method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
"""
@@ -39,7 +39,7 @@ class QuartetHfQuantizer(HfQuantizer):
requires_calibration = False
requires_parameters_quantization = True
is_qat_trainable = True
required_packages = ["qutlass", "quartet_qat", "quartet"]
required_packages = ["qutlass", "fp_quant"]
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
@@ -48,36 +48,33 @@
def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available():
raise NotImplementedError(
"Quartet quantization is only supported on GPU. Please use a different quantizer."
"FPQuant quantization is only supported on GPU. Please use a different quantizer."
)
if not is_qutlass_available():
raise ImportError("Using `quartet` quantization requires qutlass: `pip install qutlass`")
raise ImportError("Using `fp_quant` quantization requires qutlass: `pip install qutlass`")
if not is_quartet_qat_available():
raise ImportError("Using `quartet` quantization requires quartet_qat: `pip install quartet_qat`")
if not is_quartet_available():
raise ImportError("Using `quartet` quantization requires quartet: `pip install quartet`")
if not is_fp_quant_available():
raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")
if device_map is None:
raise ValueError(
"You are attempting to load a Quartet model without setting device_map."
"You are attempting to load a FPQuant model without setting device_map."
" Please set device_map comprised of 'cuda' devices."
)
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load a Quartet model with a device_map that contains a CPU or disk device."
"You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
torch_dtype = torch.float16
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.")
torch_dtype = torch.bfloat16
elif torch_dtype != torch.bfloat16:
raise ValueError(
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`."
)
return torch_dtype
@@ -91,10 +88,10 @@ class QuartetHfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
):
from quartet_qat import QuartetLinear
from fp_quant import FPQuantLinear
module, _ = get_module_from_name(model, param_name)
assert isinstance(module, QuartetLinear), f"Module {param_name} is not a QuartetLinear somehow..."
assert isinstance(module, FPQuantLinear), f"Module {param_name} is not an FPQuantLinear"
if param_name.endswith(".qweight"):
module.qweight = torch.nn.Parameter(
@@ -116,34 +113,34 @@ class QuartetHfQuantizer(HfQuantizer):
model: "PreTrainedModel",
**kwargs,
):
from quartet_qat import replace_with_quartet_linear
from fp_quant import replace_with_fp_quant_linear
from ..integrations.quartet import adapt_quartet_config
from ..integrations.fp_quant import adapt_fp_quant_config
replace_with_quartet_linear(
replace_with_fp_quant_linear(
model,
quartet_linear_config=adapt_quartet_config(self.quantization_config),
fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
from quartet_qat import QuartetLinear
from fp_quant import FPQuantLinear
quartet_modules = {name: module for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
for name, module in tqdm(quartet_modules.items(), desc="Pre-processing Quartet modules", leave=False):
fp_quant_modules = {name: module for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
for name, module in fp_quant_modules.items():
if not self.quantization_config.store_master_weights and module.weight is not None:
module.weight = None
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from quartet_qat import QuartetLinear
from fp_quant import FPQuantLinear
quartet_names = {name for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
def should_exclude(key: str) -> bool:
if key.endswith(".weight") or key.endswith(".bias"):
return False
full_key = f"{prefix}.{key}"
return any(name in key or name in full_key for name in quartet_names)
return any(name in key or name in full_key for name in fp_quant_names)
return [key for key in missing_keys if not should_exclude(key)]
@@ -162,11 +159,11 @@ class QuartetHfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
from quartet_qat import QuartetLinear
from fp_quant import FPQuantLinear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, QuartetLinear) and tensor_name in ["weight", "qweight"]:
# Only quantize weights of QuartetLinear modules that are not already quantized
if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight"]:
# Only quantize weights of FPQuantLinear modules that are not already quantized
return True
else:
return False

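A hypothetical helper (the function name is ours, not part of the commit) that mirrors the module filter used in `_process_model_after_weight_loading` and `update_missing_keys` above, e.g. to check which layers were converted and whether master weights were kept.

```python
# Hypothetical helper: collect the FPQuantLinear modules of a loaded model.
# Requires the fp_quant package; `model` is any PreTrainedModel quantized with
# FPQuantConfig. Without store_master_weights=True, module.weight is None.
from fp_quant import FPQuantLinear


def list_fp_quant_layers(model):
    return {
        name: module
        for name, module in model.named_modules()
        if isinstance(module, FPQuantLinear)
    }
```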
View File

@@ -199,8 +199,7 @@ from .import_utils import (
is_pytest_available,
is_pytorch_quantization_available,
is_quark_available,
is_quartet_available,
is_quartet_qat_available,
is_fp_quant_available,
is_qutlass_available,
is_rich_available,
is_rjieba_available,

View File

@@ -170,8 +170,7 @@ _auto_round_available, _auto_round_version = _is_package_available("auto_round",
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quark_available = _is_package_available("quark")
_quartet_available = _is_package_available("quartet")
_quartet_qat_available = _is_package_available("quartet_qat")
_fp_quant_available = _is_package_available("fp_quant")
_qutlass_available = _is_package_available("qutlass")
_is_optimum_quanto_available = False
try:
@@ -1152,12 +1151,8 @@ def is_quark_available():
return _quark_available
def is_quartet_available():
return _quartet_available
def is_quartet_qat_available():
return _quartet_qat_available
def is_fp_quant_available():
return _fp_quant_available
def is_qutlass_available():

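With the checks consolidated into `is_fp_quant_available()`, optional imports follow the usual guard pattern; a sketch mirroring the integration and quantizer files above:

```python
# Guard sketch: only touch the optional fp_quant package when it is installed.
from transformers.utils import is_fp_quant_available

if is_fp_quant_available():
    from fp_quant import FPQuantDtype, FPQuantLinear
```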
View File

@@ -63,7 +63,7 @@ class QuantizationMethod(str, Enum):
SPQR = "spqr"
FP8 = "fp8"
QUARK = "quark"
QUARTET = "quartet"
FPQUANT = "fp_quant"
AUTOROUND = "auto-round"
@@ -1552,9 +1552,9 @@ class HiggsConfig(QuantizationConfigMixin):
@dataclass
class QuartetConfig(QuantizationConfigMixin):
class FPQuantConfig(QuantizationConfigMixin):
"""
QuartetConfig is a configuration class for quantization using the Quartet method.
FPQuantConfig is a configuration class for quantization using the FPQuant method.
Args:
forward_dtype (`str`, *optional*, defaults to `"mxfp4"`):
@@ -1587,7 +1587,7 @@ class QuartetConfig(QuantizationConfigMixin):
self.hadamard_group_size = hadamard_group_size
self.modules_to_not_convert = modules_to_not_convert
self.quant_method = QuantizationMethod.QUARTET
self.quant_method = QuantizationMethod.FPQUANT
self.post_init()
def post_init(self):
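To close, a small sketch of the last hunk's effect, assuming `FPQuantConfig()` constructs with its defaults: the config now carries `QuantizationMethod.FPQUANT`, whose string value matches the `"fp_quant"` key registered in `quantizers/auto.py`.

```python
# Sketch: the renamed config reports the "fp_quant" method string.
from transformers import FPQuantConfig

cfg = FPQuantConfig()
assert cfg.quant_method == "fp_quant"  # QuantizationMethod is a str-mixin Enum
```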