[Quantization] Switch to optimum-quanto (#31732)

* switch to optimum-quanto rebase squash

* fix import check

* again

* test try-except

* style
Marc Sun 2024-10-02 15:14:34 +02:00 committed by GitHub
parent b7474f211d
commit cac4a4876b
10 changed files with 121 additions and 55 deletions


@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir gguf
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl
# Add quanto for quantization testing
RUN python3 -m pip install --no-cache-dir quanto
RUN python3 -m pip install --no-cache-dir optimum-quanto
# Add eetq for quantization testing
RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git


@ -9,14 +9,15 @@ import torch
from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
is_quanto_available,
is_torchdynamo_compiling,
logging,
)
if is_quanto_available():
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version >= version.parse("0.2.0"):
from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4
if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer
@ -754,12 +755,20 @@ class QuantoQuantizedCache(QuantizedCache):
def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
if is_optimum_quanto_available():
from optimum.quanto import MaxOptimizer, qint2, qint4
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
)
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version < version.parse("0.2.0"):
raise ImportError(
f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
f"Please upgrade quanto with `pip install -U quanto`"
f"Since quanto will be deprecated, please install optimum-quanto instead with `pip install -U optimum-quanto`"
)
from quanto import MaxOptimizer, qint2, qint4
if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@ -776,8 +785,22 @@ class QuantoQuantizedCache(QuantizedCache):
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
def _quantize(self, tensor, axis):
# We have two different APIs, since optimum-quanto no longer uses AffineQuantizer
if is_optimum_quanto_available():
from optimum.quanto import quantize_weight
scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
return qtensor
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
)
from quanto import AffineQuantizer
scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
return qtensor
def _dequantize(self, qtensor):
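For reference, a minimal sketch of the new optimum-quanto path that `_quantize` takes above, with arguments passed positionally as in the diff; the tensor shape and group size are illustrative, and the exact signatures may vary across optimum-quanto versions:

    import torch
    from optimum.quanto import MaxOptimizer, qint4, quantize_weight

    tensor = torch.randn(64, 64)
    optimizer = MaxOptimizer()
    # per-channel scale and zeropoint along axis 0, group size 64 (illustrative values)
    scale, zeropoint = optimizer(tensor, qint4, 0, 64)
    qtensor = quantize_weight(tensor, qint4, 0, scale, zeropoint, 64)
    recovered = qtensor.dequantize()  # mirrors what `_dequantize` returns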


@ -42,6 +42,7 @@ from ..utils import (
ModelOutput,
is_accelerate_available,
is_hqq_available,
is_optimum_quanto_available,
is_quanto_available,
is_torchdynamo_compiling,
logging,
@ -1674,10 +1675,10 @@ class GenerationMixin:
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
if cache_config.backend == "quanto" and not is_quanto_available():
if cache_config.backend == "quanto" and not (is_optimum_quanto_available() or is_quanto_available()):
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
"You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
"Please install it via with `pip install optimum-quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
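In practice this check fires when a quantized KV cache is requested at generation time; a hedged sketch of that call path, with an illustrative model id and cache settings:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
    # raises the ImportError above unless optimum-quanto (or legacy quanto) is installed
    out = model.generate(
        **inputs,
        max_new_tokens=20,
        cache_implementation="quantized",
        cache_config={"backend": "quanto", "nbits": 4},
    )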


@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
from ..utils import is_optimum_quanto_available, is_quanto_available, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
def replace_with_quanto_layers(
model,
@ -45,6 +47,13 @@ def replace_with_quanto_layers(
should not be passed by the user.
"""
from accelerate import init_empty_weights
if is_optimum_quanto_available():
from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
)
from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
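A minimal usage sketch of this helper, assuming accelerate is installed; the model id and config values are illustrative:

    from transformers import AutoModelForCausalLM, QuantoConfig
    from transformers.integrations.quanto import replace_with_quanto_layers

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    # replaces supported modules (e.g. nn.Linear) with their quanto counterparts
    model, has_been_replaced = replace_with_quanto_layers(
        model, quantization_config=QuantoConfig(weights="int8")
    )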


@ -23,7 +23,13 @@ from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_accelerate_available, is_quanto_available, is_torch_available, logging
from ..utils import (
is_accelerate_available,
is_optimum_quanto_available,
is_quanto_available,
is_torch_available,
logging,
)
from ..utils.quantization_config import QuantoConfig
@ -57,11 +63,13 @@ class QuantoHfQuantizer(HfQuantizer):
)
def validate_environment(self, *args, **kwargs):
if not is_quanto_available():
raise ImportError("Loading a quanto quantized model requires quanto library (`pip install quanto`)")
if not (is_optimum_quanto_available() or is_quanto_available()):
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
)
if not is_accelerate_available():
raise ImportError(
"Loading a quanto quantized model requires accelerate library (`pip install accelerate`)"
"Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
)
def update_device_map(self, device_map):
@ -81,11 +89,17 @@ class QuantoHfQuantizer(HfQuantizer):
return torch_dtype
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
import quanto
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
)
from quanto import QModuleMixin
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, quanto.QModuleMixin):
if isinstance(module, QModuleMixin):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
@ -106,7 +120,13 @@ class QuantoHfQuantizer(HfQuantizer):
"""
Check if a parameter needs to be quantized.
"""
import quanto
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
)
from quanto import QModuleMixin
device_map = kwargs.get("device_map", None)
param_device = kwargs.get("param_device", None)
@ -119,7 +139,7 @@ class QuantoHfQuantizer(HfQuantizer):
module, tensor_name = get_module_from_name(model, param_name)
# We only quantize the weights and the bias is not quantized.
if isinstance(module, quanto.QModuleMixin) and "weight" in tensor_name:
if isinstance(module, QModuleMixin) and "weight" in tensor_name:
# if the weights are quantized, don't need to recreate it again with `create_quantized_param`
return not module.frozen
else:
@ -162,7 +182,7 @@ class QuantoHfQuantizer(HfQuantizer):
return target_dtype
else:
raise ValueError(
"You are using `device_map='auto'` on a quanto quantized model. To automatically compute"
"You are using `device_map='auto'` on an optimum-quanto quantized model. To automatically compute"
" the appropriate device map, you should upgrade your `accelerate` library,"
"`pip install --upgrade accelerate` or install it from source."
)
@ -193,7 +213,7 @@ class QuantoHfQuantizer(HfQuantizer):
@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False
return True
def is_serializable(self, safe_serialization=None):
return False
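End to end, this quantizer is exercised by passing a `QuantoConfig` to `from_pretrained`; a hedged sketch, with an illustrative model id and weight dtype:

    from transformers import AutoModelForCausalLM, QuantoConfig

    quantized_model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",
        quantization_config=QuantoConfig(weights="int8"),
        device_map="cuda:0",
    )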


@ -94,6 +94,7 @@ from .utils import (
is_nltk_available,
is_onnx_available,
is_optimum_available,
is_optimum_quanto_available,
is_pandas_available,
is_peft_available,
is_phonemizer_available,
@ -102,7 +103,6 @@ from .utils import (
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_quanto_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
@ -1194,11 +1194,11 @@ def require_auto_awq(test_case):
return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
def require_quanto(test_case):
def require_optimum_quanto(test_case):
"""
Decorator for quanto dependency
"""
return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)
return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case)
def require_compressed_tensors(test_case):
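A minimal sketch of guarding a test with the renamed decorator; the class and test names are hypothetical:

    import unittest
    from transformers.testing_utils import require_optimum_quanto

    @require_optimum_quanto
    class OptimumQuantoSmokeTest(unittest.TestCase):
        def test_qtype_import(self):
            # only runs when optimum-quanto is installed; otherwise the test is skipped
            from optimum.quanto import qint8
            self.assertIsNotNone(qint8)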


@ -163,6 +163,7 @@ from .import_utils import (
is_onnx_available,
is_openai_available,
is_optimum_available,
is_optimum_quanto_available,
is_pandas_available,
is_peft_available,
is_phonemizer_available,


@ -143,6 +143,12 @@ _auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_is_optimum_quanto_available = False
try:
importlib.metadata.version("optimum_quanto")
_is_optimum_quanto_available = True
except importlib.metadata.PackageNotFoundError:
_is_optimum_quanto_available = False
# For compressed_tensors, only check spec to allow compressed_tensors-nightly package
_compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None
_pandas_available = _is_package_available("pandas")
@ -963,9 +969,17 @@ def is_auto_awq_available():
def is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
)
return _quanto_available
def is_optimum_quanto_available():
# `importlib.metadata.version` doesn't work with `optimum.quanto`, so we check `optimum_quanto` instead
return _is_optimum_quanto_available
def is_compressed_tensors_available():
return _compressed_tensors_available
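Callers are expected to gate optional imports on this flag, mirroring the pattern used elsewhere in this commit:

    from transformers.utils import is_optimum_quanto_available

    if is_optimum_quanto_available():
        from optimum.quanto import qint2, qint4, qint8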


@ -29,7 +29,7 @@ from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_auto_gptq,
require_quanto,
require_optimum_quanto,
require_torch,
require_torch_gpu,
require_torch_multi_accelerator,
@ -1941,7 +1941,7 @@ class GenerationTesterMixin:
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
@require_quanto
@require_optimum_quanto
@pytest.mark.generate
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:


@ -19,13 +19,13 @@ import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import (
require_accelerate,
require_quanto,
require_optimum_quanto,
require_read_token,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
from transformers.utils import is_accelerate_available, is_optimum_quanto_available, is_torch_available
if is_torch_available():
@ -36,8 +36,8 @@ if is_torch_available():
if is_accelerate_available():
from accelerate import init_empty_weights
if is_quanto_available():
from quanto import QLayerNorm, QLinear
if is_optimum_quanto_available():
from optimum.quanto import QLayerNorm, QLinear
from transformers.integrations.quanto import replace_with_quanto_layers
@ -47,7 +47,7 @@ class QuantoConfigTest(unittest.TestCase):
pass
@require_quanto
@require_optimum_quanto
@require_accelerate
class QuantoTestIntegration(unittest.TestCase):
model_id = "facebook/opt-350m"
@ -124,7 +124,7 @@ class QuantoTestIntegration(unittest.TestCase):
@slow
@require_torch_gpu
@require_quanto
@require_optimum_quanto
@require_accelerate
class QuantoQuantizationTest(unittest.TestCase):
"""
@ -187,7 +187,7 @@ class QuantoQuantizationTest(unittest.TestCase):
self.check_inference_correctness(self.quantized_model, "cuda")
def test_quantized_model_layers(self):
from quanto import QBitsTensor, QModuleMixin, QTensor
from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
"""
Suite of simple test to check if the layers are quantized and are working properly
@ -256,7 +256,7 @@ class QuantoQuantizationTest(unittest.TestCase):
self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
def test_compare_with_quanto(self):
from quanto import freeze, qint4, qint8, quantize
from optimum.quanto import freeze, qint4, qint8, quantize
w_mapping = {"int8": qint8, "int4": qint4}
model = AutoModelForCausalLM.from_pretrained(
@ -272,7 +272,7 @@ class QuantoQuantizationTest(unittest.TestCase):
@unittest.skip
def test_load_from_quanto_saved(self):
from quanto import freeze, qint4, qint8, quantize
from optimum.quanto import freeze, qint4, qint8, quantize
from transformers import QuantoConfig
@ -356,7 +356,7 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
"""
We check that we have unquantized value in the cpu and in the disk
"""
import quanto
from optimum.quanto import QBitsTensor, QTensor
cpu_weights = self.quantized_model.transformer.h[22].self_attention.query_key_value._hf_hook.weights_map[
"weight"
@ -364,13 +364,11 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
disk_weights = self.quantized_model.transformer.h[23].self_attention.query_key_value._hf_hook.weights_map[
"weight"
]
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QTensor))
self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QTensor))
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, QTensor))
self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QTensor))
if self.weights == "int4":
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor))
self.assertTrue(
isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor)
)
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor))
self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor))
@unittest.skip(reason="Skipping test class because serialization is not supported yet")
@ -416,18 +414,18 @@ class QuantoQuantizationSerializationCudaTest(QuantoQuantizationTest):
class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest):
EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
weights = "int4"
class QuantoQuantizationQBitsTensorOffloadTest(QuantoQuantizationOffloadTest):
EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
weights = "int4"
@unittest.skip(reason="Skipping test class because serialization is not supported yet")
class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializationTest):
EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
weights = "int4"
@ -443,14 +441,14 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
@require_quanto
@require_optimum_quanto
@require_torch_gpu
class QuantoKVCacheQuantizationTest(unittest.TestCase):
@slow
@require_read_token
def test_quantized_cache(self):
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory is the most",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]