mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 07:18:58 +06:00
quartet -> fp_quant
This commit is contained in:
parent
7199b608f2
commit
c2b5b29a8d
@ -93,9 +93,9 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
||||
|
||||
[[autodoc]] QuarkConfig
|
||||
|
||||
## QuartetConfig
|
||||
## FPQuantConfig
|
||||
|
||||
[[autodoc]] QuartetConfig
|
||||
[[autodoc]] FPQuantConfig
|
||||
|
||||
## AutoRoundConfig
|
||||
|
||||
|
@ -272,7 +272,7 @@ _import_structure = {
|
||||
"HqqConfig",
|
||||
"QuantoConfig",
|
||||
"QuarkConfig",
|
||||
"QuartetConfig",
|
||||
"FPQuantConfig",
|
||||
"SpQRConfig",
|
||||
"TorchAoConfig",
|
||||
"VptqConfig",
|
||||
@ -772,7 +772,7 @@ if TYPE_CHECKING:
|
||||
HqqConfig,
|
||||
QuantoConfig,
|
||||
QuarkConfig,
|
||||
QuartetConfig,
|
||||
FPQuantConfig,
|
||||
SpQRConfig,
|
||||
TorchAoConfig,
|
||||
VptqConfig,
|
||||
|
@ -11,10 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"Quartet integration file"
|
||||
"FPQuant integration file"
|
||||
|
||||
from ..utils import (
|
||||
is_quartet_qat_available,
|
||||
is_fp_quant_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
@ -23,27 +23,27 @@ if is_torch_available():
|
||||
pass
|
||||
|
||||
|
||||
if is_quartet_qat_available():
|
||||
from quartet_qat import QuartetConfig as QuartetLinearConfig
|
||||
from quartet_qat import QuartetDtype
|
||||
if is_fp_quant_available():
|
||||
from fp_quant import FPQuantConfig as FPQuantLinearConfig
|
||||
from fp_quant import FPQuantDtype
|
||||
|
||||
from transformers.utils.quantization_config import QuartetConfig
|
||||
from transformers.utils.quantization_config import FPQuantConfig
|
||||
|
||||
|
||||
def adapt_quartet_config(config: QuartetConfig):
|
||||
def adapt_fp_quant_config(config: FPQuantConfig):
|
||||
if config.forward_dtype == "mxfp4":
|
||||
forward_dtype = QuartetDtype.MXFP4
|
||||
forward_dtype = FPQuantDtype.MXFP4
|
||||
else:
|
||||
raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}")
|
||||
|
||||
if config.backward_dtype == "mxfp4":
|
||||
backward_dtype = QuartetDtype.MXFP4
|
||||
backward_dtype = FPQuantDtype.MXFP4
|
||||
elif config.backward_dtype == "bf16":
|
||||
backward_dtype = QuartetDtype.BF16
|
||||
backward_dtype = FPQuantDtype.BF16
|
||||
else:
|
||||
raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")
|
||||
|
||||
return QuartetLinearConfig(
|
||||
return FPQuantLinearConfig(
|
||||
forward_dtype=forward_dtype,
|
||||
forward_method=config.forward_method,
|
||||
backward_dtype=backward_dtype,
|
@ -34,7 +34,7 @@ from ..utils.quantization_config import (
|
||||
QuantizationMethod,
|
||||
QuantoConfig,
|
||||
QuarkConfig,
|
||||
QuartetConfig,
|
||||
FPQuantConfig,
|
||||
SpQRConfig,
|
||||
TorchAoConfig,
|
||||
VptqConfig,
|
||||
@ -55,7 +55,7 @@ from .quantizer_higgs import HiggsHfQuantizer
|
||||
from .quantizer_hqq import HqqHfQuantizer
|
||||
from .quantizer_quanto import QuantoHfQuantizer
|
||||
from .quantizer_quark import QuarkHfQuantizer
|
||||
from .quantizer_quartet import QuartetHfQuantizer
|
||||
from .quantizer_fp_quant import FPQuantHfQuantizer
|
||||
from .quantizer_spqr import SpQRHfQuantizer
|
||||
from .quantizer_torchao import TorchAoHfQuantizer
|
||||
from .quantizer_vptq import VptqHfQuantizer
|
||||
@ -69,7 +69,7 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"aqlm": AqlmHfQuantizer,
|
||||
"quanto": QuantoHfQuantizer,
|
||||
"quark": QuarkHfQuantizer,
|
||||
"quartet": QuartetHfQuantizer,
|
||||
"fp_quant": FPQuantHfQuantizer,
|
||||
"eetq": EetqHfQuantizer,
|
||||
"higgs": HiggsHfQuantizer,
|
||||
"hqq": HqqHfQuantizer,
|
||||
@ -92,7 +92,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"aqlm": AqlmConfig,
|
||||
"quanto": QuantoConfig,
|
||||
"quark": QuarkConfig,
|
||||
"quartet": QuartetConfig,
|
||||
"fp_quant": FPQuantConfig,
|
||||
"hqq": HqqConfig,
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"fbgemm_fp8": FbgemmFp8Config,
|
||||
|
@ -21,7 +21,7 @@ from .quantizers_utils import get_module_from_name
|
||||
if TYPE_CHECKING:
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
from ..utils import is_quartet_available, is_quartet_qat_available, is_qutlass_available, is_torch_available, logging
|
||||
from ..utils import is_fp_quant_available, is_fp_quant_available, is_qutlass_available, is_torch_available, logging
|
||||
from ..utils.quantization_config import QuantizationConfigMixin
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ if is_torch_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QuartetHfQuantizer(HfQuantizer):
|
||||
class FPQuantHfQuantizer(HfQuantizer):
|
||||
"""
|
||||
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
|
||||
"""
|
||||
@ -39,7 +39,7 @@ class QuartetHfQuantizer(HfQuantizer):
|
||||
requires_calibration = False
|
||||
requires_parameters_quantization = True
|
||||
is_qat_trainable = True
|
||||
required_packages = ["qutlass", "quartet_qat", "quartet"]
|
||||
required_packages = ["qutlass", "fp_quant"]
|
||||
|
||||
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
@ -48,36 +48,33 @@ class QuartetHfQuantizer(HfQuantizer):
|
||||
def validate_environment(self, device_map, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
raise NotImplementedError(
|
||||
"Quartet quantization is only supported on GPU. Please use a different quantizer."
|
||||
"FPQuant quantization is only supported on GPU. Please use a different quantizer."
|
||||
)
|
||||
|
||||
if not is_qutlass_available():
|
||||
raise ImportError("Using `quartet` quantization requires qutlass: `pip install qutlass`")
|
||||
raise ImportError("Using `fp_quant` quantization requires qutlass: `pip install qutlass`")
|
||||
|
||||
if not is_quartet_qat_available():
|
||||
raise ImportError("Using `quartet` quantization requires quartet_qat: `pip install quartet_qat`")
|
||||
|
||||
if not is_quartet_available():
|
||||
raise ImportError("Using `quartet` quantization requires quartet: `pip install quartet`")
|
||||
if not is_fp_quant_available():
|
||||
raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")
|
||||
|
||||
if device_map is None:
|
||||
raise ValueError(
|
||||
"You are attempting to load a Quartet model without setting device_map."
|
||||
"You are attempting to load a FPQuant model without setting device_map."
|
||||
" Please set device_map comprised of 'cuda' devices."
|
||||
)
|
||||
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
|
||||
raise ValueError(
|
||||
"You are attempting to load a Quartet model with a device_map that contains a CPU or disk device."
|
||||
"You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
|
||||
" This is not supported. Please remove the CPU or disk device from the device_map."
|
||||
)
|
||||
|
||||
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if torch_dtype is None:
|
||||
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
|
||||
torch_dtype = torch.float16
|
||||
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
|
||||
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.")
|
||||
torch_dtype = torch.bfloat16
|
||||
elif torch_dtype != torch.bfloat16:
|
||||
raise ValueError(
|
||||
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
|
||||
f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`."
|
||||
)
|
||||
|
||||
return torch_dtype
|
||||
@ -91,10 +88,10 @@ class QuartetHfQuantizer(HfQuantizer):
|
||||
state_dict: Dict[str, Any],
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
):
|
||||
from quartet_qat import QuartetLinear
|
||||
from fp_quant import FPQuantLinear
|
||||
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
assert isinstance(module, QuartetLinear), f"Module {param_name} is not a QuartetLinear somehow..."
|
||||
assert isinstance(module, FPQuantLinear), f"Module {param_name} is not a FPQuantLinear somehow..."
|
||||
|
||||
if param_name.endswith(".qweight"):
|
||||
module.qweight = torch.nn.Parameter(
|
||||
@ -116,34 +113,34 @@ class QuartetHfQuantizer(HfQuantizer):
|
||||
model: "PreTrainedModel",
|
||||
**kwargs,
|
||||
):
|
||||
from quartet_qat import replace_with_quartet_linear
|
||||
from fp_quant import replace_with_fp_quant_linear
|
||||
|
||||
from ..integrations.quartet import adapt_quartet_config
|
||||
from ..integrations.fp_quant import adapt_fp_quant_config
|
||||
|
||||
replace_with_quartet_linear(
|
||||
replace_with_fp_quant_linear(
|
||||
model,
|
||||
quartet_linear_config=adapt_quartet_config(self.quantization_config),
|
||||
fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
|
||||
)
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
from quartet_qat import QuartetLinear
|
||||
from fp_quant import FPQuantLinear
|
||||
|
||||
quartet_modules = {name: module for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
|
||||
for name, module in tqdm(quartet_modules.items(), desc="Pre-processing Quartet modules", leave=False):
|
||||
fp_quant_modules = {name: module for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
|
||||
for name, module in fp_quant_modules.items():
|
||||
if not self.quantization_config.store_master_weights and module.weight is not None:
|
||||
module.weight = None
|
||||
|
||||
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
|
||||
from quartet_qat import QuartetLinear
|
||||
from fp_quant import FPQuantLinear
|
||||
|
||||
quartet_names = {name for name, module in model.named_modules() if isinstance(module, QuartetLinear)}
|
||||
fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
|
||||
|
||||
def should_exclude(key: str) -> bool:
|
||||
if key.endswith(".weight") or key.endswith(".bias"):
|
||||
return False
|
||||
full_key = f"{prefix}.{key}"
|
||||
return any(name in key or name in full_key for name in quartet_names)
|
||||
return any(name in key or name in full_key for name in fp_quant_names)
|
||||
|
||||
return [key for key in missing_keys if not should_exclude(key)]
|
||||
|
||||
@ -162,11 +159,11 @@ class QuartetHfQuantizer(HfQuantizer):
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
from quartet_qat import QuartetLinear
|
||||
from fp_quant import FPQuantLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module, QuartetLinear) and tensor_name in ["weight", "qweight"]:
|
||||
# Only quantize weights of QuartetLinear modules that are not already quantized
|
||||
if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight"]:
|
||||
# Only quantize weights of FPQuantLinear modules that are not already quantized
|
||||
return True
|
||||
else:
|
||||
return False
|
@ -199,8 +199,7 @@ from .import_utils import (
|
||||
is_pytest_available,
|
||||
is_pytorch_quantization_available,
|
||||
is_quark_available,
|
||||
is_quartet_available,
|
||||
is_quartet_qat_available,
|
||||
is_fp_quant_available,
|
||||
is_qutlass_available,
|
||||
is_rich_available,
|
||||
is_rjieba_available,
|
||||
|
@ -170,8 +170,7 @@ _auto_round_available, _auto_round_version = _is_package_available("auto_round",
|
||||
# `importlib.metadata.version` doesn't work with `awq`
|
||||
_auto_awq_available = importlib.util.find_spec("awq") is not None
|
||||
_quark_available = _is_package_available("quark")
|
||||
_quartet_available = _is_package_available("quartet")
|
||||
_quartet_qat_available = _is_package_available("quartet_qat")
|
||||
_fp_quant_available = _is_package_available("fp_quant")
|
||||
_qutlass_available = _is_package_available("qutlass")
|
||||
_is_optimum_quanto_available = False
|
||||
try:
|
||||
@ -1152,12 +1151,8 @@ def is_quark_available():
|
||||
return _quark_available
|
||||
|
||||
|
||||
def is_quartet_available():
|
||||
return _quartet_available
|
||||
|
||||
|
||||
def is_quartet_qat_available():
|
||||
return _quartet_qat_available
|
||||
def is_fp_quant_available():
|
||||
return _fp_quant_available
|
||||
|
||||
|
||||
def is_qutlass_available():
|
||||
|
@ -63,7 +63,7 @@ class QuantizationMethod(str, Enum):
|
||||
SPQR = "spqr"
|
||||
FP8 = "fp8"
|
||||
QUARK = "quark"
|
||||
QUARTET = "quartet"
|
||||
FPQUANT = "fp_quant"
|
||||
AUTOROUND = "auto-round"
|
||||
|
||||
|
||||
@ -1552,9 +1552,9 @@ class HiggsConfig(QuantizationConfigMixin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuartetConfig(QuantizationConfigMixin):
|
||||
class FPQuantConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
QuartetConfig is a configuration class for quantization using the Quartet method.
|
||||
FPQuantConfig is a configuration class for quantization using the FPQuant method.
|
||||
|
||||
Args:
|
||||
forward_dtype (`str`, *optional*, defaults to `"mxfp4"`):
|
||||
@ -1587,7 +1587,7 @@ class QuartetConfig(QuantizationConfigMixin):
|
||||
self.hadamard_group_size = hadamard_group_size
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
self.quant_method = QuantizationMethod.QUARTET
|
||||
self.quant_method = QuantizationMethod.FPQUANT
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
|
Loading…
Reference in New Issue
Block a user