quartet -> fp_quant

Repository: https://github.com/huggingface/transformers.git
Commit: c2b5b29a8d (parent 7199b608f2)
@@ -93,9 +93,9 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 [[autodoc]] QuarkConfig
 
-## QuartetConfig
+## FPQuantConfig
 
-[[autodoc]] QuartetConfig
+[[autodoc]] FPQuantConfig
 
 ## AutoRoundConfig
 
@@ -272,7 +272,7 @@ _import_structure = {
         "HqqConfig",
         "QuantoConfig",
         "QuarkConfig",
-        "QuartetConfig",
+        "FPQuantConfig",
         "SpQRConfig",
         "TorchAoConfig",
         "VptqConfig",
@@ -772,7 +772,7 @@ if TYPE_CHECKING:
         HqqConfig,
         QuantoConfig,
         QuarkConfig,
-        QuartetConfig,
+        FPQuantConfig,
         SpQRConfig,
         TorchAoConfig,
         VptqConfig,
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"Quartet integration file"
+"FPQuant integration file"
 
 from ..utils import (
-    is_quartet_qat_available,
+    is_fp_quant_available,
     is_torch_available,
 )
 
@@ -23,27 +23,27 @@ if is_torch_available():
     pass
 
 
-if is_quartet_qat_available():
-    from quartet_qat import QuartetConfig as QuartetLinearConfig
-    from quartet_qat import QuartetDtype
+if is_fp_quant_available():
+    from fp_quant import FPQuantConfig as FPQuantLinearConfig
+    from fp_quant import FPQuantDtype
 
-from transformers.utils.quantization_config import QuartetConfig
+from transformers.utils.quantization_config import FPQuantConfig
 
 
-def adapt_quartet_config(config: QuartetConfig):
+def adapt_fp_quant_config(config: FPQuantConfig):
     if config.forward_dtype == "mxfp4":
-        forward_dtype = QuartetDtype.MXFP4
+        forward_dtype = FPQuantDtype.MXFP4
     else:
         raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}")
 
     if config.backward_dtype == "mxfp4":
-        backward_dtype = QuartetDtype.MXFP4
+        backward_dtype = FPQuantDtype.MXFP4
     elif config.backward_dtype == "bf16":
-        backward_dtype = QuartetDtype.BF16
+        backward_dtype = FPQuantDtype.BF16
     else:
         raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")
 
-    return QuartetLinearConfig(
+    return FPQuantLinearConfig(
         forward_dtype=forward_dtype,
         forward_method=config.forward_method,
         backward_dtype=backward_dtype,
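For context, a minimal usage sketch of the renamed helper (not part of the commit): it assumes the `fp_quant` package is installed, and the `FPQuantConfig` keyword names are assumptions based on the attributes the adapter reads above.

# Hypothetical sketch: convert the transformers-side FPQuantConfig (string dtypes)
# into the fp_quant library config (FPQuantDtype enum members).
from transformers.integrations.fp_quant import adapt_fp_quant_config
from transformers.utils.quantization_config import FPQuantConfig

cfg = FPQuantConfig(forward_dtype="mxfp4", backward_dtype="bf16")  # keyword names assumed
linear_cfg = adapt_fp_quant_config(cfg)  # fp_quant-side config used to build FPQuantLinear layers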
@@ -34,7 +34,7 @@ from ..utils.quantization_config import (
     QuantizationMethod,
     QuantoConfig,
     QuarkConfig,
-    QuartetConfig,
+    FPQuantConfig,
     SpQRConfig,
     TorchAoConfig,
     VptqConfig,
@@ -55,7 +55,7 @@ from .quantizer_higgs import HiggsHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
 from .quantizer_quanto import QuantoHfQuantizer
 from .quantizer_quark import QuarkHfQuantizer
-from .quantizer_quartet import QuartetHfQuantizer
+from .quantizer_fp_quant import FPQuantHfQuantizer
 from .quantizer_spqr import SpQRHfQuantizer
 from .quantizer_torchao import TorchAoHfQuantizer
 from .quantizer_vptq import VptqHfQuantizer
@@ -69,7 +69,7 @@ AUTO_QUANTIZER_MAPPING = {
     "aqlm": AqlmHfQuantizer,
     "quanto": QuantoHfQuantizer,
     "quark": QuarkHfQuantizer,
-    "quartet": QuartetHfQuantizer,
+    "fp_quant": FPQuantHfQuantizer,
     "eetq": EetqHfQuantizer,
     "higgs": HiggsHfQuantizer,
     "hqq": HqqHfQuantizer,
@@ -92,7 +92,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "quark": QuarkConfig,
-    "quartet": QuartetConfig,
+    "fp_quant": FPQuantConfig,
     "hqq": HqqConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "fbgemm_fp8": FbgemmFp8Config,
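A quick sanity check one might run against a checkout containing this change (not part of the commit); it assumes the two mappings are importable from transformers.quantizers.auto, as in the hunks above.

# Hypothetical check: the new "fp_quant" key dispatches to the renamed quantizer and config.
from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

assert AUTO_QUANTIZER_MAPPING["fp_quant"].__name__ == "FPQuantHfQuantizer"
assert AUTO_QUANTIZATION_CONFIG_MAPPING["fp_quant"].__name__ == "FPQuantConfig"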
@@ -21,7 +21,7 @@ from .quantizers_utils import get_module_from_name
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_quartet_available, is_quartet_qat_available, is_qutlass_available, is_torch_available, logging
+from ..utils import is_fp_quant_available, is_fp_quant_available, is_qutlass_available, is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
 
 
@@ -31,7 +31,7 @@ if is_torch_available():
 logger = logging.get_logger(__name__)
 
 
-class QuartetHfQuantizer(HfQuantizer):
+class FPQuantHfQuantizer(HfQuantizer):
     """
     Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
     """
@@ -39,7 +39,7 @@ class QuartetHfQuantizer(HfQuantizer):
     requires_calibration = False
     requires_parameters_quantization = True
     is_qat_trainable = True
-    required_packages = ["qutlass", "quartet_qat", "quartet"]
+    required_packages = ["qutlass", "fp_quant"]
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
@@ -48,36 +48,33 @@ class QuartetHfQuantizer(HfQuantizer):
     def validate_environment(self, device_map, **kwargs):
         if not torch.cuda.is_available():
             raise NotImplementedError(
-                "Quartet quantization is only supported on GPU. Please use a different quantizer."
+                "FPQuant quantization is only supported on GPU. Please use a different quantizer."
             )
 
         if not is_qutlass_available():
-            raise ImportError("Using `quartet` quantization requires qutlass: `pip install qutlass`")
+            raise ImportError("Using `fp_quant` quantization requires qutlass: `pip install qutlass`")
 
-        if not is_quartet_qat_available():
-            raise ImportError("Using `quartet` quantization requires quartet_qat: `pip install quartet_qat`")
-
-        if not is_quartet_available():
-            raise ImportError("Using `quartet` quantization requires quartet: `pip install quartet`")
+        if not is_fp_quant_available():
+            raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")
 
         if device_map is None:
             raise ValueError(
-                "You are attempting to load a Quartet model without setting device_map."
+                "You are attempting to load a FPQuant model without setting device_map."
                 " Please set device_map comprised of 'cuda' devices."
             )
         elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
             raise ValueError(
-                "You are attempting to load a Quartet model with a device_map that contains a CPU or disk device."
+                "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
                 " This is not supported. Please remove the CPU or disk device from the device_map."
             )
 
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
-            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
-            torch_dtype = torch.float16
-        elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
+            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.")
+            torch_dtype = torch.bfloat16
+        elif torch_dtype != torch.bfloat16:
             raise ValueError(
-                f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
+                f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`."
            )
 
         return torch_dtype
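Taken together, the checks above require a CUDA-only device_map and bfloat16. A hedged end-to-end loading sketch (not part of the commit; the checkpoint name is a placeholder and `qutlass`/`fp_quant` are assumed to be installed):

# Hypothetical loading sketch reflecting the constraints enforced above.
import torch
from transformers import AutoModelForCausalLM, FPQuantConfig

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",               # placeholder checkpoint
    quantization_config=FPQuantConfig(),
    torch_dtype=torch.bfloat16,          # only bfloat16 passes update_torch_dtype
    device_map="cuda",                   # CPU/disk entries are rejected by validate_environment
)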
@@ -91,10 +88,10 @@ class QuartetHfQuantizer(HfQuantizer):
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
     ):
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
         module, _ = get_module_from_name(model, param_name)
-        assert isinstance(module, QuartetLinear), f"Module {param_name} is not a QuartetLinear somehow..."
+        assert isinstance(module, FPQuantLinear), f"Module {param_name} is not a FPQuantLinear somehow..."
 
         if param_name.endswith(".qweight"):
             module.qweight = torch.nn.Parameter(
@@ -116,34 +113,34 @@ class QuartetHfQuantizer(HfQuantizer):
         model: "PreTrainedModel",
         **kwargs,
     ):
-        from quartet_qat import replace_with_quartet_linear
+        from fp_quant import replace_with_fp_quant_linear
 
-        from ..integrations.quartet import adapt_quartet_config
+        from ..integrations.fp_quant import adapt_fp_quant_config
 
-        replace_with_quartet_linear(
+        replace_with_fp_quant_linear(
             model,
-            quartet_linear_config=adapt_quartet_config(self.quantization_config),
+            fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
         )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
-        quartet_modules = {name: module for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
-        for name, module in tqdm(quartet_modules.items(), desc="Pre-processing Quartet modules", leave=False):
+        fp_quant_modules = {name: module for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
+        for name, module in fp_quant_modules.items():
             if not self.quantization_config.store_master_weights and module.weight is not None:
                 module.weight = None
 
     def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
-        quartet_names = {name for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
+        fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
 
         def should_exclude(key: str) -> bool:
             if key.endswith(".weight") or key.endswith(".bias"):
                 return False
             full_key = f"{prefix}.{key}"
-            return any(name in key or name in full_key for name in quartet_names)
+            return any(name in key or name in full_key for name in fp_quant_names)
 
         return [key for key in missing_keys if not should_exclude(key)]
 
@@ -162,11 +159,11 @@ class QuartetHfQuantizer(HfQuantizer):
         state_dict: Dict[str, Any],
         **kwargs,
     ) -> bool:
-        from quartet_qat import QuartetLinear
+        from fp_quant import FPQuantLinear
 
         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module, QuartetLinear) and tensor_name in ["weight", "qweight"]:
-            # Only quantize weights of QuartetLinear modules that are not already quantized
+        if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight"]:
+            # Only quantize weights of FPQuantLinear modules that are not already quantized
             return True
         else:
             return False
@@ -199,8 +199,7 @@ from .import_utils import (
     is_pytest_available,
     is_pytorch_quantization_available,
     is_quark_available,
-    is_quartet_available,
-    is_quartet_qat_available,
+    is_fp_quant_available,
     is_qutlass_available,
     is_rich_available,
     is_rjieba_available,
@@ -170,8 +170,7 @@ _auto_round_available, _auto_round_version = _is_package_available("auto_round",
 # `importlib.metadata.version` doesn't work with `awq`
 _auto_awq_available = importlib.util.find_spec("awq") is not None
 _quark_available = _is_package_available("quark")
-_quartet_available = _is_package_available("quartet")
-_quartet_qat_available = _is_package_available("quartet_qat")
+_fp_quant_available = _is_package_available("fp_quant")
 _qutlass_available = _is_package_available("qutlass")
 _is_optimum_quanto_available = False
 try:
@@ -1152,12 +1151,8 @@ def is_quark_available():
     return _quark_available
 
 
-def is_quartet_available():
-    return _quartet_available
-
-
-def is_quartet_qat_available():
-    return _quartet_qat_available
+def is_fp_quant_available():
+    return _fp_quant_available
 
 
 def is_qutlass_available():
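The two old availability helpers collapse into one; a small sketch (not part of the commit) of how the consolidated flag is checked:

# Hypothetical check: a single flag now gates the fp_quant integration.
from transformers.utils import is_fp_quant_available, is_qutlass_available

if is_fp_quant_available() and is_qutlass_available():
    print("fp_quant path available")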
@@ -63,7 +63,7 @@ class QuantizationMethod(str, Enum):
     SPQR = "spqr"
     FP8 = "fp8"
     QUARK = "quark"
-    QUARTET = "quartet"
+    FPQUANT = "fp_quant"
     AUTOROUND = "auto-round"
 
 
@@ -1552,9 +1552,9 @@ class HiggsConfig(QuantizationConfigMixin):
 
 
 @dataclass
-class QuartetConfig(QuantizationConfigMixin):
+class FPQuantConfig(QuantizationConfigMixin):
     """
-    QuartetConfig is a configuration class for quantization using the Quartet method.
+    FPQuantConfig is a configuration class for quantization using the FPQuant method.
 
     Args:
         forward_dtype (`str`, *optional*, defaults to `"mxfp4"`):
@@ -1587,7 +1587,7 @@ class QuartetConfig(QuantizationConfigMixin):
         self.hadamard_group_size = hadamard_group_size
         self.modules_to_not_convert = modules_to_not_convert
 
-        self.quant_method = QuantizationMethod.QUARTET
+        self.quant_method = QuantizationMethod.FPQUANT
         self.post_init()
 
     def post_init(self):
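Finally, a hedged sketch of how the renamed config now reports the new method key (not part of the commit; keyword names beyond forward_dtype/backward_dtype are assumptions based on the attributes set above):

# Hypothetical sketch: the config carries quant_method "fp_quant" after this change.
from transformers import FPQuantConfig

cfg = FPQuantConfig(forward_dtype="mxfp4", backward_dtype="bf16")  # keyword names assumed
assert cfg.quant_method == "fp_quant"  # QuantizationMethod is a str Enum, so it compares equal to the string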