From cac4a4876b5c263e51b1c0e8887f35cfb6f266b1 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:14:34 +0200 Subject: [PATCH] [Quantization] Switch to optimum-quanto (#31732) * switch to optimum-quanto rebase squach * fix import check * again * test try-except * style --- .../Dockerfile | 2 +- src/transformers/cache_utils.py | 49 ++++++++++++++----- src/transformers/generation/utils.py | 7 +-- src/transformers/integrations/quanto.py | 13 ++++- .../quantizers/quantizer_quanto.py | 40 +++++++++++---- src/transformers/testing_utils.py | 6 +-- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 14 ++++++ tests/generation/test_utils.py | 4 +- .../quanto_integration/test_quanto.py | 40 +++++++-------- 10 files changed, 121 insertions(+), 55 deletions(-) diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index 6d94dbee5aa..0617ac8cdd7 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir gguf RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl # Add quanto for quantization testing -RUN python3 -m pip install --no-cache-dir quanto +RUN python3 -m pip install --no-cache-dir optimum-quanto # Add eetq for quantization testing RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0b82b17dcde..223eda10a96 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -9,14 +9,15 @@ import torch from packaging import version from .configuration_utils import PretrainedConfig -from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging +from .utils import ( + is_hqq_available, + is_optimum_quanto_available, + is_quanto_available, + is_torchdynamo_compiling, + logging, +) -if is_quanto_available(): - quanto_version = version.parse(importlib.metadata.version("quanto")) - if quanto_version >= version.parse("0.2.0"): - from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4 - if is_hqq_available(): from hqq.core.quantize import Quantizer as HQQQuantizer @@ -754,12 +755,20 @@ class QuantoQuantizedCache(QuantizedCache): def __init__(self, cache_config: CacheConfig) -> None: super().__init__(cache_config) - quanto_version = version.parse(importlib.metadata.version("quanto")) - if quanto_version < version.parse("0.2.0"): - raise ImportError( - f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. " - f"Please upgrade quanto with `pip install -U quanto`" + + if is_optimum_quanto_available(): + from optimum.quanto import MaxOptimizer, qint2, qint4 + elif is_quanto_available(): + logger.warning_once( + "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`" ) + quanto_version = version.parse(importlib.metadata.version("quanto")) + if quanto_version < version.parse("0.2.0"): + raise ImportError( + f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. " + f"Since quanto will be deprecated, please install optimum-quanto instead with `pip install -U optimum-quanto`" + ) + from quanto import MaxOptimizer, qint2, qint4 if self.nbits not in [2, 4]: raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") @@ -776,8 +785,22 @@ class QuantoQuantizedCache(QuantizedCache): self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization def _quantize(self, tensor, axis): - scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size) - qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint) + # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight + + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor + elif is_quanto_available(): + logger.warning_once( + "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`" + ) + from quanto import AffineQuantizer + + scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size) + qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint) + return qtensor def _dequantize(self, qtensor): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 661c8b579af..06b2654248f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -42,6 +42,7 @@ from ..utils import ( ModelOutput, is_accelerate_available, is_hqq_available, + is_optimum_quanto_available, is_quanto_available, is_torchdynamo_compiling, logging, @@ -1674,10 +1675,10 @@ class GenerationMixin: ) cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] - if cache_config.backend == "quanto" and not is_quanto_available(): + if cache_config.backend == "quanto" and not (is_optimum_quanto_available() or is_quanto_available()): raise ImportError( - "You need to install `quanto` in order to use KV cache quantization with quanto backend. " - "Please install it via with `pip install quanto`" + "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. " + "Please install it via with `pip install optimum-quanto`" ) elif cache_config.backend == "HQQ" and not is_hqq_available(): raise ImportError( diff --git a/src/transformers/integrations/quanto.py b/src/transformers/integrations/quanto.py index 67fe9166d33..27b32de63bf 100644 --- a/src/transformers/integrations/quanto.py +++ b/src/transformers/integrations/quanto.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..utils import is_torch_available +from ..utils import is_optimum_quanto_available, is_quanto_available, is_torch_available, logging if is_torch_available(): import torch +logger = logging.get_logger(__name__) + def replace_with_quanto_layers( model, @@ -45,7 +47,14 @@ def replace_with_quanto_layers( should not be passed by the user. """ from accelerate import init_empty_weights - from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8 + + if is_optimum_quanto_available(): + from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8 + elif is_quanto_available(): + logger.warning_once( + "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`" + ) + from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8 w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2} a_mapping = {None: None, "float8": qfloat8, "int8": qint8} diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py index ae113f714ac..0aacc18d2a1 100644 --- a/src/transformers/quantizers/quantizer_quanto.py +++ b/src/transformers/quantizers/quantizer_quanto.py @@ -23,7 +23,13 @@ from .quantizers_utils import get_module_from_name if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..utils import is_accelerate_available, is_quanto_available, is_torch_available, logging +from ..utils import ( + is_accelerate_available, + is_optimum_quanto_available, + is_quanto_available, + is_torch_available, + logging, +) from ..utils.quantization_config import QuantoConfig @@ -57,11 +63,13 @@ class QuantoHfQuantizer(HfQuantizer): ) def validate_environment(self, *args, **kwargs): - if not is_quanto_available(): - raise ImportError("Loading a quanto quantized model requires quanto library (`pip install quanto`)") + if not (is_optimum_quanto_available() or is_quanto_available()): + raise ImportError( + "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)" + ) if not is_accelerate_available(): raise ImportError( - "Loading a quanto quantized model requires accelerate library (`pip install accelerate`)" + "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)" ) def update_device_map(self, device_map): @@ -81,11 +89,17 @@ class QuantoHfQuantizer(HfQuantizer): return torch_dtype def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: - import quanto + if is_optimum_quanto_available(): + from optimum.quanto import QModuleMixin + elif is_quanto_available(): + logger.warning_once( + "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`" + ) + from quanto import QModuleMixin not_missing_keys = [] for name, module in model.named_modules(): - if isinstance(module, quanto.QModuleMixin): + if isinstance(module, QModuleMixin): for missing in missing_keys: if ( (name in missing or name in f"{prefix}.{missing}") @@ -106,7 +120,13 @@ class QuantoHfQuantizer(HfQuantizer): """ Check if a parameter needs to be quantized. """ - import quanto + if is_optimum_quanto_available(): + from optimum.quanto import QModuleMixin + elif is_quanto_available(): + logger.warning_once( + "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`" + ) + from quanto import QModuleMixin device_map = kwargs.get("device_map", None) param_device = kwargs.get("param_device", None) @@ -119,7 +139,7 @@ class QuantoHfQuantizer(HfQuantizer): module, tensor_name = get_module_from_name(model, param_name) # We only quantize the weights and the bias is not quantized. - if isinstance(module, quanto.QModuleMixin) and "weight" in tensor_name: + if isinstance(module, QModuleMixin) and "weight" in tensor_name: # if the weights are quantized, don't need to recreate it again with `create_quantized_param` return not module.frozen else: @@ -162,7 +182,7 @@ class QuantoHfQuantizer(HfQuantizer): return target_dtype else: raise ValueError( - "You are using `device_map='auto'` on a quanto quantized model. To automatically compute" + "You are using `device_map='auto'` on an optimum-quanto quantized model. To automatically compute" " the appropriate device map, you should upgrade your `accelerate` library," "`pip install --upgrade accelerate` or install it from source." ) @@ -193,7 +213,7 @@ class QuantoHfQuantizer(HfQuantizer): @property def is_trainable(self, model: Optional["PreTrainedModel"] = None): - return False + return True def is_serializable(self, safe_serialization=None): return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 4986de42e0f..8eda45bd40e 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -94,6 +94,7 @@ from .utils import ( is_nltk_available, is_onnx_available, is_optimum_available, + is_optimum_quanto_available, is_pandas_available, is_peft_available, is_phonemizer_available, @@ -102,7 +103,6 @@ from .utils import ( is_pytesseract_available, is_pytest_available, is_pytorch_quantization_available, - is_quanto_available, is_rjieba_available, is_sacremoses_available, is_safetensors_available, @@ -1194,11 +1194,11 @@ def require_auto_awq(test_case): return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case) -def require_quanto(test_case): +def require_optimum_quanto(test_case): """ Decorator for quanto dependency """ - return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case) + return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case) def require_compressed_tensors(test_case): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 134da3474be..3b33127be4b 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -163,6 +163,7 @@ from .import_utils import ( is_onnx_available, is_openai_available, is_optimum_available, + is_optimum_quanto_available, is_pandas_available, is_peft_available, is_phonemizer_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 2b2302e0dcf..519755489a3 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -143,6 +143,12 @@ _auto_gptq_available = _is_package_available("auto_gptq") # `importlib.metadata.version` doesn't work with `awq` _auto_awq_available = importlib.util.find_spec("awq") is not None _quanto_available = _is_package_available("quanto") +_is_optimum_quanto_available = False +try: + importlib.metadata.version("optimum_quanto") + _is_optimum_quanto_available = True +except importlib.metadata.PackageNotFoundError: + _is_optimum_quanto_available = False # For compressed_tensors, only check spec to allow compressed_tensors-nightly package _compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None _pandas_available = _is_package_available("pandas") @@ -963,9 +969,17 @@ def is_auto_awq_available(): def is_quanto_available(): + logger.warning_once( + "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`" + ) return _quanto_available +def is_optimum_quanto_available(): + # `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto` + return _is_optimum_quanto_available + + def is_compressed_tensors_available(): return _compressed_tensors_available diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index beb5fc7818f..b5feba6a300 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -29,7 +29,7 @@ from transformers.testing_utils import ( is_flaky, require_accelerate, require_auto_gptq, - require_quanto, + require_optimum_quanto, require_torch, require_torch_gpu, require_torch_multi_accelerator, @@ -1941,7 +1941,7 @@ class GenerationTesterMixin: self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers) self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape) - @require_quanto + @require_optimum_quanto @pytest.mark.generate def test_generate_with_quant_cache(self): for model_class in self.all_generative_model_classes: diff --git a/tests/quantization/quanto_integration/test_quanto.py b/tests/quantization/quanto_integration/test_quanto.py index d8f4fffb8d2..08cc48d0ccc 100644 --- a/tests/quantization/quanto_integration/test_quanto.py +++ b/tests/quantization/quanto_integration/test_quanto.py @@ -19,13 +19,13 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig from transformers.testing_utils import ( require_accelerate, - require_quanto, + require_optimum_quanto, require_read_token, require_torch_gpu, slow, torch_device, ) -from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available +from transformers.utils import is_accelerate_available, is_optimum_quanto_available, is_torch_available if is_torch_available(): @@ -36,8 +36,8 @@ if is_torch_available(): if is_accelerate_available(): from accelerate import init_empty_weights -if is_quanto_available(): - from quanto import QLayerNorm, QLinear +if is_optimum_quanto_available(): + from optimum.quanto import QLayerNorm, QLinear from transformers.integrations.quanto import replace_with_quanto_layers @@ -47,7 +47,7 @@ class QuantoConfigTest(unittest.TestCase): pass -@require_quanto +@require_optimum_quanto @require_accelerate class QuantoTestIntegration(unittest.TestCase): model_id = "facebook/opt-350m" @@ -124,7 +124,7 @@ class QuantoTestIntegration(unittest.TestCase): @slow @require_torch_gpu -@require_quanto +@require_optimum_quanto @require_accelerate class QuantoQuantizationTest(unittest.TestCase): """ @@ -187,7 +187,7 @@ class QuantoQuantizationTest(unittest.TestCase): self.check_inference_correctness(self.quantized_model, "cuda") def test_quantized_model_layers(self): - from quanto import QBitsTensor, QModuleMixin, QTensor + from optimum.quanto import QBitsTensor, QModuleMixin, QTensor """ Suite of simple test to check if the layers are quantized and are working properly @@ -256,7 +256,7 @@ class QuantoQuantizationTest(unittest.TestCase): self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device))) def test_compare_with_quanto(self): - from quanto import freeze, qint4, qint8, quantize + from optimum.quanto import freeze, qint4, qint8, quantize w_mapping = {"int8": qint8, "int4": qint4} model = AutoModelForCausalLM.from_pretrained( @@ -272,7 +272,7 @@ class QuantoQuantizationTest(unittest.TestCase): @unittest.skip def test_load_from_quanto_saved(self): - from quanto import freeze, qint4, qint8, quantize + from optimum.quanto import freeze, qint4, qint8, quantize from transformers import QuantoConfig @@ -356,7 +356,7 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest): """ We check that we have unquantized value in the cpu and in the disk """ - import quanto + from optimum.quanto import QBitsTensor, QTensor cpu_weights = self.quantized_model.transformer.h[22].self_attention.query_key_value._hf_hook.weights_map[ "weight" @@ -364,13 +364,11 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest): disk_weights = self.quantized_model.transformer.h[23].self_attention.query_key_value._hf_hook.weights_map[ "weight" ] - self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QTensor)) - self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QTensor)) + self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, QTensor)) + self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QTensor)) if self.weights == "int4": - self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor)) - self.assertTrue( - isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor) - ) + self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor)) + self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor)) @unittest.skip(reason="Skipping test class because serialization is not supported yet") @@ -416,18 +414,18 @@ class QuantoQuantizationSerializationCudaTest(QuantoQuantizationTest): class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest): - EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University" + EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I" weights = "int4" class QuantoQuantizationQBitsTensorOffloadTest(QuantoQuantizationOffloadTest): - EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University" + EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I" weights = "int4" @unittest.skip(reason="Skipping test class because serialization is not supported yet") class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializationTest): - EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University" + EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I" weights = "int4" @@ -443,14 +441,14 @@ class QuantoQuantizationActivationTest(unittest.TestCase): self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception)) -@require_quanto +@require_optimum_quanto @require_torch_gpu class QuantoKVCacheQuantizationTest(unittest.TestCase): @slow @require_read_token def test_quantized_cache(self): EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", + "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory is the most", "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", ]