Efficient Inference Kernel for SpQR (#34976)

* Resolve vptq conflict

* Rename spqr package to spqr_quant

* Get rid of aqlm mention

* Start working on tests

* Resolve ruff code checks

* Ruff format

* Isort

* Test updates

* Add gpu tag

* Rename to modules_to_not_convert

* Config update

* Docs and config update

* Docs and config update

* Update to update_torch_dtype

* spqr config parameter validation

* Ruff update

* Apply ruff fixes

* Test fixes

* Ruff update

* Mark tests as @slow again; Ruff; Docstring update

* Ruff

* Remove absolute path

* Resolve typo

* Remove redundant log

* Check accelerate/spqr availability

* Ruff fix

* Check if the config contains proper shapes

* Ruff test

* Documentation update

* overview update

* Ruff checks

* Ruff code quality

* Make style

* Update docs/source/en/quantization/spqr.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update spqr.md

* Enable gptqmodel (#35012)

* gptqmodel

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update readme

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* gptqmodel need use checkpoint_format (#1)

* gptqmodel need use checkpoint_format

* fix quantize

* Update quantization_config.py

* Update quantization_config.py

* Update quantization_config.py

---------

Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>

* Revert quantizer_gptq.py (#2)

* revert quantizer_gptq.py change

* pass **kwargs

* limit gptqmodel and optimum version

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix warning

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix version check

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert unrelated changes

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable gptqmodel tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix requires gptq

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix Transformer compat (#3)

* revert quantizer_gptq.py change

* pass **kwargs

* add meta info

* cleanup

* cleanup

* Update quantization_config.py

* hf_select_quant_linear pass checkpoint_format and meta

* fix GPTQTestCUDA

* Update test_gptq.py

* gptqmodel.hf_select_quant_linear() now does not select ExllamaV2

* cleanup

* add backend

* cleanup

* cleanup

* no need check exllama version

* Update quantization_config.py

* lower checkpoint_format and backend

* check none

* cleanup

* Update quantization_config.py

* fix self.use_exllama == False

* spell

* fix unittest

* fix unittest

---------

Co-authored-by: LRL <lrl@lbx.dev>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format again

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update gptqmodel version (#6)

* update gptqmodel version

* update gptqmodel version

* fix unit test (#5)

* update gptqmodel version

* update gptqmodel version

* "not self.use_exllama" is not equivalent to "self.use_exllama==False"

* fix unittest

* update gptqmodel version

* backend is loading_attributes (#7)

* fix format and tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix memory check

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix device mismatch

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix result check

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Update src/transformers/quantizers/quantizer_gptq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_gptq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_gptq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* update tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* review: update docs (#10)

* review: update docs (#12)

* review: update docs

* fix typo

* update tests for gptqmodel

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update document (#9)

* update overview.md

* cleanup

* Update overview.md

* Update overview.md

* Update overview.md

* update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

---------

Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>

* typo

* doc note for asymmetric quant

* typo with apple silicon(e)

* typo for marlin

* column name revert: review

* doc rocm support

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/overview.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/overview.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Co-authored-by: LRL <lrl@lbx.dev>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Fix : Nemotron Processor in GGUF conversion (#35708)

* fixing nemotron processor

* make style

* Update docs/source/en/quantization/spqr.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add missing TOC to doc

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Co-authored-by: LRL <lrl@lbx.dev>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Commit 845b0a2616 (parent c5506f4f00), authored by Elvir Crnčević on 2025-02-13 16:22:58 +01:00 and committed via GitHub.
16 changed files with 591 additions and 0 deletions

View File

@ -53,6 +53,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
# Add vptq for quantization testing
RUN python3 -m pip install --no-cache-dir vptq
# Add spqr for quantization testing
RUN python3 -m pip install --no-cache-dir spqr_quant[gpu]
# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq

View File

@ -166,6 +166,8 @@
- local: quantization/aqlm
title: AQLM
- local: quantization/vptq
title: VPTQ
- local: quantization/spqr
title: SpQR
- local: quantization/quanto
title: Quanto

View File

@ -81,6 +81,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] BitNetConfig
## SpQRConfig
[[autodoc]] SpQRConfig
## FineGrainedFP8Config
[[autodoc]] FineGrainedFP8Config

View File

@ -61,6 +61,7 @@ Use the table below to help you decide which quantization method to use.
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
<Tip>

View File

@ -0,0 +1,35 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# SpQR
The [SpQR](https://github.com/Vahe1994/SpQR) quantization algorithm combines a 16x16 tiled, bi-level group 3-bit quantization structure with sparse outliers, as detailed in [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression](https://arxiv.org/abs/2306.03078).
To quantize a model with SpQR, refer to the [Vahe1994/SpQR](https://github.com/Vahe1994/SpQR) repository.
Load a pre-quantized SpQR model with [`~PreTrainedModel.from_pretrained`].
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
quantized_model = AutoModelForCausalLM.from_pretrained(
"elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf",
torch_dtype=torch.half,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf")
```
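Continuing from the snippet above, generation works as with any other Transformers model. A minimal sketch (the prompt and `max_new_tokens` value are illustrative choices, not part of the original example):

```python
inputs = tokenizer("Hello my name is", return_tensors="pt").to(quantized_model.device)
outputs = quantized_model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```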

View File

@ -1029,6 +1029,7 @@ _import_structure = {
"HiggsConfig",
"HqqConfig",
"QuantoConfig",
"SpQRConfig",
"TorchAoConfig",
"VptqConfig",
],
@ -6202,6 +6203,7 @@ if TYPE_CHECKING:
HiggsConfig,
HqqConfig,
QuantoConfig,
SpQRConfig,
TorchAoConfig,
VptqConfig,
)

View File

@ -106,6 +106,7 @@ _import_structure = {
],
"peft": ["PeftAdapterMixin"],
"quanto": ["replace_with_quanto_layers"],
"spqr": ["replace_with_spqr_linear"],
"vptq": ["replace_with_vptq_linear"],
}
@ -210,6 +211,7 @@ if TYPE_CHECKING:
)
from .peft import PeftAdapterMixin
from .quanto import replace_with_quanto_layers
from .spqr import replace_with_spqr_linear
from .vptq import replace_with_vptq_linear
try:

View File

@ -0,0 +1,122 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"SpQR (Sparse-Quantized Representation) integration file"
from ..utils import is_accelerate_available, is_spqr_available, is_torch_available
if is_torch_available():
import torch.nn as nn
def replace_with_spqr_linear(
model,
quantization_config=None,
modules_to_not_convert=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`SpQRConfig`):
The quantization config object that contains the quantization parameters.
modules_to_not_convert (`list[str]`, *optional*):
A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
converted.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []
if is_accelerate_available():
from accelerate import init_empty_weights
if is_spqr_available():
from spqr_quant import QuantizedLinear
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear):
# Check if the current key is not in the `modules_to_not_convert`
if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
with init_empty_weights():
tensor_name = ".".join(current_key_name)
shapes = quantization_config.shapes
shapes_keys = shapes.keys()
shapes_valid = (
f"{tensor_name}.dense_weights.shape" in shapes_keys
and f"{tensor_name}.row_offsets.shape" in shapes_keys
and f"{tensor_name}.col_vals.shape" in shapes_keys
and f"{tensor_name}.in_perm.shape" in shapes_keys
)
if not shapes_valid:
raise ValueError(
f"The SpQR quantization config does not contain the shape "
f"configuration for {tensor_name}. This indicates that the "
f"configuration is either invalid or corrupted."
)
dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]
in_features = module.in_features
out_features = module.out_features
model._modules[name] = QuantizedLinear.create_placehodler(
rows=out_features,
cols=in_features,
bits=quantization_config.bits,
beta1=quantization_config.beta1,
beta2=quantization_config.beta2,
dense_weights_shape=dense_weights_shape,
row_offsets_shape=row_offsets_shape,
col_vals_shape=col_vals_shape,
in_perm_shape=in_perm_shape,
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
else:
pass
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_spqr_linear(
module,
quantization_config=quantization_config,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
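A minimal sketch of calling this helper directly, outside of `from_pretrained` (which is how the SpQR quantizer normally invokes it). It assumes `accelerate` and `spqr_quant[gpu]` are installed, reads the per-tensor shapes from the pre-quantized checkpoint's config as the integration test does, and instantiates the (gated) base architecture with `from_config` rather than downloading its fp16 weights:

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, SpQRConfig
from transformers.integrations import replace_with_spqr_linear

quantized_repo = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"

# The shapes of the compressed tensors are stored in the checkpoint's quantization config.
quantization_config = SpQRConfig.from_dict(
    AutoConfig.from_pretrained(quantized_repo).quantization_config
)

# Instantiate the base architecture on the meta device, then swap its nn.Linear layers.
base_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(base_config)

model, has_been_replaced = replace_with_spqr_linear(
    model,
    quantization_config=quantization_config,
    modules_to_not_convert=quantization_config.modules_to_not_convert,
)
print(has_been_replaced)  # True once at least one layer became a QuantizedLinear placeholder
```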

View File

@ -31,6 +31,7 @@ from ..utils.quantization_config import (
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
SpQRConfig,
TorchAoConfig,
VptqConfig,
)
@ -47,6 +48,7 @@ from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_spqr import SpQRHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer
@ -66,6 +68,7 @@ AUTO_QUANTIZER_MAPPING = {
"torchao": TorchAoHfQuantizer,
"bitnet": BitNetHfQuantizer,
"vptq": VptqHfQuantizer,
"spqr": SpQRHfQuantizer,
"fp8": FineGrainedFP8HfQuantizer,
}
@ -84,6 +87,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
"vptq": VptqConfig,
"spqr": SpQRConfig,
"fp8": FineGrainedFP8Config,
}

View File

@ -0,0 +1,83 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..integrations import replace_with_spqr_linear
from ..utils import is_accelerate_available, is_spqr_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class SpQRHfQuantizer(HfQuantizer):
"""
Quantizer of the SpQR method. Enables the loading of prequantized models.
"""
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run SpQR quantized model.")
if not is_accelerate_available():
raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")
if not is_spqr_available():
raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
torch_dtype = torch.float16
logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
elif torch_dtype != torch.float16:
raise ValueError(
"You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to"
"torch.float16 explicitly."
)
return torch_dtype
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
**kwargs,
):
replace_with_spqr_linear(
model,
quantization_config=self.quantization_config,
modules_to_not_convert=self.quantization_config.modules_to_not_convert,
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False
def is_serializable(self, safe_serialization=None):
return True
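From the user's side, the `update_torch_dtype` logic above behaves roughly as in the sketch below (the checkpoint name is the one used in the docs and tests; loading it requires a CUDA GPU per `validate_environment`):

```python
import torch
from transformers import AutoModelForCausalLM

repo = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"

# Leaving torch_dtype unset is fine: the quantizer resolves it to torch.float16.
model = AutoModelForCausalLM.from_pretrained(repo, device_map="auto")

# Any other dtype is rejected with a ValueError.
try:
    AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16, device_map="auto")
except ValueError as err:
    print(err)
```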

View File

@ -121,6 +121,7 @@ from .utils import (
is_seqio_available,
is_soundfile_available,
is_spacy_available,
is_spqr_available,
is_sudachi_available,
is_sudachi_projection_available,
is_tensorflow_probability_available,
@ -1191,6 +1192,13 @@ def require_vptq(test_case):
return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
def require_spqr(test_case):
"""
Decorator marking a test that requires spqr
"""
return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case)
def require_eetq(test_case):
"""
Decorator marking a test that requires eetq

View File

@ -193,6 +193,7 @@ from .import_utils import (
is_soundfile_available,
is_spacy_available,
is_speech_available,
is_spqr_available,
is_sudachi_available,
is_sudachi_projection_available,
is_tensorflow_probability_available,

View File

@ -201,6 +201,7 @@ _tiktoken_available = _is_package_available("tiktoken")
_blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")
_triton_available = _is_package_available("triton")
_spqr_available = _is_package_available("spqr_quant")
_torch_version = "N/A"
_torch_available = False
@ -1213,6 +1214,10 @@ def is_speech_available():
return _torchaudio_available
def is_spqr_available():
return _spqr_available
def is_phonemizer_available():
return _phonemizer_available

View File

@ -56,6 +56,7 @@ class QuantizationMethod(str, Enum):
FBGEMM_FP8 = "fbgemm_fp8"
TORCHAO = "torchao"
BITNET = "bitnet"
SPQR = "spqr"
FP8 = "fp8"
@ -1551,6 +1552,75 @@ class BitNetConfig(QuantizationConfigMixin):
pass
@dataclass
class SpQRConfig(QuantizationConfigMixin):
"""
This is a wrapper class for the `spqr` quantization parameters. Refer to the original publication for more details.
Args:
bits (`int`, *optional*, defaults to 3):
Specifies the bit count for the weights and first order zero-points and scales.
Currently only bits = 3 is supported.
beta1 (`int`, *optional*, defaults to 16):
SpQR tile width. Currently only beta1 = 16 is supported.
beta2 (`int`, *optional*, defaults to 16):
SpQR tile height. Currently only beta2 = 16 is supported.
shapes (`Optional[Dict[str, int]]`, *optional*):
A dictionary holding the shape of each quantized object (dense weights, row offsets, column values and input permutation per tensor).
This is needed because the exact sizes of the compressed parameters cannot be deduced from bits, beta1 and beta2 alone.
modules_to_not_convert (`Optional[List[str]]`, *optional*):
Optionally, provides a list of full paths of `nn.Linear` weight parameters that shall not be quantized.
Defaults to None.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
"""
def __init__(
self,
bits: int = 3,
beta1: int = 16,
beta2: int = 16,
shapes: Optional[Dict[str, int]] = None,
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
):
if shapes is None:
shapes = {}
self.shapes = shapes
self.quant_method = QuantizationMethod.SPQR
self.bits = bits
self.beta1 = beta1
self.beta2 = beta2
if modules_to_not_convert is None:
modules_to_not_convert = []
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if not isinstance(self.bits, int):
raise TypeError("bits must be an int")
if not isinstance(self.beta1, int):
raise TypeError("beta1 must be an int")
if not isinstance(self.beta2, int):
raise TypeError("beta2 must be an int")
if self.bits != 3:
raise ValueError("SpQR currently only supports bits = 3")
if self.beta1 != 16:
raise ValueError("SpQR currently only supports beta1 = 16")
if self.beta2 != 16:
raise ValueError("SpQR currently only supports beta2 = 16")
if self.modules_to_not_convert is not None and not isinstance(self.modules_to_not_convert, list):
raise ValueError("modules_to_not_convert must be a list of strings")
if not isinstance(self.shapes, dict):
raise TypeError("shapes must be a dict")
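As a quick illustration of the validation above, a small sketch (the error message is the one raised by `post_init`):

```python
from transformers import SpQRConfig

# The defaults describe the only layout the inference kernel currently supports.
config = SpQRConfig()  # bits=3, beta1=16, beta2=16, shapes={}, modules_to_not_convert=[]

# Unsupported settings are rejected as soon as the config is constructed.
try:
    SpQRConfig(bits=4)
except ValueError as err:
    print(err)  # SpQR currently only supports bits = 3
```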
@dataclass
class FineGrainedFP8Config(QuantizationConfigMixin):
"""

View File

@ -0,0 +1,249 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, SpQRConfig, StaticCache
from transformers.testing_utils import (
require_accelerate,
require_spqr,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class SpQRConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
"""
quantization_config = SpQRConfig()
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
def test_from_dict(self):
"""
Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
"""
dict = {
"beta1": 16,
"beta2": 16,
"bits": 3,
"modules_to_not_convert": ["lm_head.weight"],
"shapes": {"model.layers.0.self_attn.q_proj.dense_weights.shape": 16},
}
quantization_config = SpQRConfig.from_dict(dict)
self.assertEqual(dict["beta1"], quantization_config.beta1)
self.assertEqual(dict["beta2"], quantization_config.beta2)
self.assertEqual(dict["bits"], quantization_config.bits)
self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
self.assertEqual(dict["shapes"], quantization_config.shapes)
@slow
@require_torch_gpu
@require_spqr
@require_accelerate
class SpQRTest(unittest.TestCase):
model_name = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"
input_text = "Hello my name is"
max_new_tokens = 32
EXPECTED_OUTPUT = (
"Hello my name is Jesse. (I'm also known as Jesse) I'm a 25 year old male from United States. I'm looking for"
)
EXPECTED_OUTPUT_COMPILE = "Hello my name is Jake and I am a 20 year old student at the University of North Texas. (Go Mean Green!) I am a huge fan of the Dallas"
device_map = "cuda"
# called only once for all test in this class
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
device_map=cls.device_map,
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from spqr_quant import QuantizedLinear
from transformers.integrations import replace_with_spqr_linear
model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_id)
quantization_config = AutoConfig.from_pretrained(self.model_name, return_dict=False).quantization_config
quantization_config = SpQRConfig.from_dict(quantization_config)
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id, config=config)
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model, _ = replace_with_spqr_linear(
model,
quantization_config=quantization_config,
modules_to_not_convert=quantization_config.modules_to_not_convert,
)
nb_spqr_linear = 0
for module in model.modules():
if isinstance(module, QuantizedLinear):
nb_spqr_linear += 1
self.assertEqual(nb_linears - 1, nb_spqr_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_raise_if_non_quantized(self):
model_id = "meta-llama/Llama-2-7b-hf"
quantization_config = SpQRConfig()
with self.assertRaises(ValueError):
_ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
@unittest.skip
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_quantized_model_compile(self):
"""
Simple test that checks if the quantized model is working properly
"""
# Sample tokens greedily
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
logits = model(
cur_token,
position_ids=input_pos,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0]
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
return new_token
# Tokenize the test input
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)["input_ids"]
seq_length = input_ids.shape[1]
# Setup static KV cache for generation
past_key_values = StaticCache(
config=self.quantized_model.config,
batch_size=1,
max_cache_len=seq_length + self.max_new_tokens + 1,
device=torch_device,
dtype=self.quantized_model.config._pre_quantization_dtype,
)
# Allocate token ids to be generated and copy prefix ids
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(1, seq_length + self.max_new_tokens, dtype=torch.int, device=torch_device)
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
# Do a forward pass to fill the prefix cache and compile the kernels if necessary
logits = self.quantized_model(
input_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token
with torch.no_grad():
# Compile the CUDA graph
decode_one_tokens = torch.compile(decode_one_tokens, mode="default", backend="inductor", fullgraph=True)
# Generate tokens one by one
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, self.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = decode_one_tokens(
self.quantized_model, next_token.clone(), None, cache_position, past_key_values
)
generated_ids.index_copy_(1, cache_position, next_token)
cache_position += 1
# Check generated text
self.assertEqual(
self.tokenizer.decode(generated_ids[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_COMPILE
)