Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
[Quantization] Switch to optimum-quanto (#31732)
* switch to optimum-quanto rebase squash
* fix import check
* again
* test try-except
* style
This commit is contained in:
parent b7474f211d
commit cac4a4876b
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir gguf
 RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl
 
 # Add quanto for quantization testing
-RUN python3 -m pip install --no-cache-dir quanto
+RUN python3 -m pip install --no-cache-dir optimum-quanto
 
 # Add eetq for quantization testing
 RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git
@@ -9,14 +9,15 @@ import torch
 from packaging import version
 
 from .configuration_utils import PretrainedConfig
-from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging
-
-
-if is_quanto_available():
-    quanto_version = version.parse(importlib.metadata.version("quanto"))
-    if quanto_version >= version.parse("0.2.0"):
-        from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4
+from .utils import (
+    is_hqq_available,
+    is_optimum_quanto_available,
+    is_quanto_available,
+    is_torchdynamo_compiling,
+    logging,
+)
 
 if is_hqq_available():
     from hqq.core.quantize import Quantizer as HQQQuantizer
@@ -754,12 +755,20 @@ class QuantoQuantizedCache(QuantizedCache):
 
     def __init__(self, cache_config: CacheConfig) -> None:
         super().__init__(cache_config)
-        quanto_version = version.parse(importlib.metadata.version("quanto"))
-        if quanto_version < version.parse("0.2.0"):
-            raise ImportError(
-                f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
-                f"Please upgrade quanto with `pip install -U quanto`"
-            )
+
+        if is_optimum_quanto_available():
+            from optimum.quanto import MaxOptimizer, qint2, qint4
+        elif is_quanto_available():
+            logger.warning_once(
+                "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
+            )
+            quanto_version = version.parse(importlib.metadata.version("quanto"))
+            if quanto_version < version.parse("0.2.0"):
+                raise ImportError(
+                    f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
+                    f"Since quanto will be deprecated, please install optimum-quanto instead with `pip install -U optimum-quanto`"
+                )
+            from quanto import MaxOptimizer, qint2, qint4
 
         if self.nbits not in [2, 4]:
             raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -776,8 +785,22 @@ class QuantoQuantizedCache(QuantizedCache):
         self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization
 
     def _quantize(self, tensor, axis):
-        scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
-        qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
+        # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
+        if is_optimum_quanto_available():
+            from optimum.quanto import quantize_weight
+
+            scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
+            qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
+            return qtensor
+        elif is_quanto_available():
+            logger.warning_once(
+                "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
+            )
+            from quanto import AffineQuantizer
+
+            scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
+            qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
+
         return qtensor
 
     def _dequantize(self, qtensor):
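For context, `QuantoQuantizedCache` is the backend behind KV-cache quantization in `generate()`. A minimal usage sketch of that feature, assuming optimum-quanto and a recent transformers release are installed; the checkpoint and generation settings below are illustrative, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-350m"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)

# Quantize the KV cache with the "quanto" backend, which this change routes through optimum-quanto.
out = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.decode(out[0], skip_special_tokens=True))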
@@ -42,6 +42,7 @@ from ..utils import (
     ModelOutput,
     is_accelerate_available,
     is_hqq_available,
+    is_optimum_quanto_available,
     is_quanto_available,
     is_torchdynamo_compiling,
     logging,
@@ -1674,10 +1675,10 @@ class GenerationMixin:
             )
         cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
 
-        if cache_config.backend == "quanto" and not is_quanto_available():
+        if cache_config.backend == "quanto" and not (is_optimum_quanto_available() or is_quanto_available()):
             raise ImportError(
-                "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
-                "Please install it via with `pip install quanto`"
+                "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
+                "Please install it via with `pip install optimum-quanto`"
             )
         elif cache_config.backend == "HQQ" and not is_hqq_available():
             raise ImportError(
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils import is_torch_available
+from ..utils import is_optimum_quanto_available, is_quanto_available, is_torch_available, logging
 
 
 if is_torch_available():
     import torch
 
+logger = logging.get_logger(__name__)
+
 
 def replace_with_quanto_layers(
     model,
@@ -45,7 +47,14 @@ def replace_with_quanto_layers(
             should not be passed by the user.
     """
     from accelerate import init_empty_weights
-    from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
+
+    if is_optimum_quanto_available():
+        from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
+    elif is_quanto_available():
+        logger.warning_once(
+            "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
+        )
+        from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
 
     w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
     a_mapping = {None: None, "float8": qfloat8, "int8": qint8}
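`replace_with_quanto_layers` is the entry point that runs when a model is loaded with a `QuantoConfig`. A short sketch of that load path, assuming optimum-quanto and accelerate are installed; the checkpoint name is illustrative:

from transformers import AutoModelForCausalLM, QuantoConfig

# Weights-only quantization; activations stay in full precision,
# matching the int8/int4/int2/float8 weight mappings set up above.
quantization_config = QuantoConfig(weights="int8")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    device_map="auto",
    quantization_config=quantization_config,
)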
@@ -23,7 +23,13 @@ from .quantizers_utils import get_module_from_name
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_accelerate_available, is_quanto_available, is_torch_available, logging
+from ..utils import (
+    is_accelerate_available,
+    is_optimum_quanto_available,
+    is_quanto_available,
+    is_torch_available,
+    logging,
+)
 from ..utils.quantization_config import QuantoConfig
 
@@ -57,11 +63,13 @@ class QuantoHfQuantizer(HfQuantizer):
         )
 
     def validate_environment(self, *args, **kwargs):
-        if not is_quanto_available():
-            raise ImportError("Loading a quanto quantized model requires quanto library (`pip install quanto`)")
+        if not (is_optimum_quanto_available() or is_quanto_available()):
+            raise ImportError(
+                "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
+            )
         if not is_accelerate_available():
             raise ImportError(
-                "Loading a quanto quantized model requires accelerate library (`pip install accelerate`)"
+                "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
             )
 
     def update_device_map(self, device_map):
@@ -81,11 +89,17 @@ class QuantoHfQuantizer(HfQuantizer):
         return torch_dtype
 
     def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
-        import quanto
+        if is_optimum_quanto_available():
+            from optimum.quanto import QModuleMixin
+        elif is_quanto_available():
+            logger.warning_once(
+                "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
+            )
+            from quanto import QModuleMixin
 
         not_missing_keys = []
         for name, module in model.named_modules():
-            if isinstance(module, quanto.QModuleMixin):
+            if isinstance(module, QModuleMixin):
                 for missing in missing_keys:
                     if (
                         (name in missing or name in f"{prefix}.{missing}")
|
|||||||
"""
|
"""
|
||||||
Check if a parameter needs to be quantized.
|
Check if a parameter needs to be quantized.
|
||||||
"""
|
"""
|
||||||
import quanto
|
if is_optimum_quanto_available():
|
||||||
|
from optimum.quanto import QModuleMixin
|
||||||
|
elif is_quanto_available():
|
||||||
|
logger.warning_once(
|
||||||
|
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
|
||||||
|
)
|
||||||
|
from quanto import QModuleMixin
|
||||||
|
|
||||||
device_map = kwargs.get("device_map", None)
|
device_map = kwargs.get("device_map", None)
|
||||||
param_device = kwargs.get("param_device", None)
|
param_device = kwargs.get("param_device", None)
|
||||||
@ -119,7 +139,7 @@ class QuantoHfQuantizer(HfQuantizer):
|
|||||||
|
|
||||||
module, tensor_name = get_module_from_name(model, param_name)
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
# We only quantize the weights and the bias is not quantized.
|
# We only quantize the weights and the bias is not quantized.
|
||||||
if isinstance(module, quanto.QModuleMixin) and "weight" in tensor_name:
|
if isinstance(module, QModuleMixin) and "weight" in tensor_name:
|
||||||
# if the weights are quantized, don't need to recreate it again with `create_quantized_param`
|
# if the weights are quantized, don't need to recreate it again with `create_quantized_param`
|
||||||
return not module.frozen
|
return not module.frozen
|
||||||
else:
|
else:
|
||||||
@@ -162,7 +182,7 @@ class QuantoHfQuantizer(HfQuantizer):
             return target_dtype
         else:
             raise ValueError(
-                "You are using `device_map='auto'` on a quanto quantized model. To automatically compute"
+                "You are using `device_map='auto'` on an optimum-quanto quantized model. To automatically compute"
                 " the appropriate device map, you should upgrade your `accelerate` library,"
                 "`pip install --upgrade accelerate` or install it from source."
             )
@@ -193,7 +213,7 @@ class QuantoHfQuantizer(HfQuantizer):
 
     @property
     def is_trainable(self, model: Optional["PreTrainedModel"] = None):
-        return False
+        return True
 
     def is_serializable(self, safe_serialization=None):
         return False
@@ -94,6 +94,7 @@ from .utils import (
     is_nltk_available,
     is_onnx_available,
     is_optimum_available,
+    is_optimum_quanto_available,
     is_pandas_available,
     is_peft_available,
     is_phonemizer_available,
@@ -102,7 +103,6 @@ from .utils import (
     is_pytesseract_available,
     is_pytest_available,
     is_pytorch_quantization_available,
-    is_quanto_available,
     is_rjieba_available,
     is_sacremoses_available,
     is_safetensors_available,
@@ -1194,11 +1194,11 @@ def require_auto_awq(test_case):
     return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
 
 
-def require_quanto(test_case):
+def require_optimum_quanto(test_case):
     """
     Decorator for quanto dependency
     """
-    return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)
+    return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case)
 
 
 def require_compressed_tensors(test_case):
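The renamed decorator is applied exactly like the old `require_quanto`: it skips a test when optimum-quanto is not installed. A hedged sketch of a test that uses it; the test class, checkpoint, and assertion are illustrative, not part of this commit:

import unittest

from transformers import AutoModelForCausalLM, QuantoConfig
from transformers.testing_utils import require_accelerate, require_optimum_quanto


@require_optimum_quanto
@require_accelerate
class MyQuantoSmokeTest(unittest.TestCase):  # illustrative test class
    def test_loads_quantized(self):
        # Skipped automatically when optimum-quanto or accelerate is missing.
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m", quantization_config=QuantoConfig(weights="int8")
        )
        self.assertIsNotNone(model)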
@@ -163,6 +163,7 @@ from .import_utils import (
     is_onnx_available,
     is_openai_available,
     is_optimum_available,
+    is_optimum_quanto_available,
     is_pandas_available,
     is_peft_available,
     is_phonemizer_available,
@@ -143,6 +143,12 @@ _auto_gptq_available = _is_package_available("auto_gptq")
 # `importlib.metadata.version` doesn't work with `awq`
 _auto_awq_available = importlib.util.find_spec("awq") is not None
 _quanto_available = _is_package_available("quanto")
+_is_optimum_quanto_available = False
+try:
+    importlib.metadata.version("optimum_quanto")
+    _is_optimum_quanto_available = True
+except importlib.metadata.PackageNotFoundError:
+    _is_optimum_quanto_available = False
 # For compressed_tensors, only check spec to allow compressed_tensors-nightly package
 _compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None
 _pandas_available = _is_package_available("pandas")
@@ -963,9 +969,17 @@ def is_auto_awq_available():
 
 
 def is_quanto_available():
+    logger.warning_once(
+        "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
+    )
     return _quanto_available
 
 
+def is_optimum_quanto_available():
+    # `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto`
+    return _is_optimum_quanto_available
+
+
 def is_compressed_tensors_available():
     return _compressed_tensors_available
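The new helper checks distribution metadata rather than importing the package: `importlib.metadata.version` wants the distribution name (`optimum_quanto`), not the module path (`optimum.quanto`). A standalone sketch of the same detection pattern, with a hypothetical function name for illustration:

import importlib.metadata


def optimum_quanto_installed() -> bool:
    # Query the installed distribution ("optimum-quanto" normalizes to "optimum_quanto");
    # this avoids importing the package just to see whether it is present.
    try:
        importlib.metadata.version("optimum_quanto")
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


print(optimum_quanto_installed())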
@@ -29,7 +29,7 @@ from transformers.testing_utils import (
     is_flaky,
     require_accelerate,
     require_auto_gptq,
-    require_quanto,
+    require_optimum_quanto,
     require_torch,
     require_torch_gpu,
     require_torch_multi_accelerator,
@@ -1941,7 +1941,7 @@ class GenerationTesterMixin:
         self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
         self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
 
-    @require_quanto
+    @require_optimum_quanto
     @pytest.mark.generate
     def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
@@ -19,13 +19,13 @@ import unittest
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
 from transformers.testing_utils import (
     require_accelerate,
-    require_quanto,
+    require_optimum_quanto,
     require_read_token,
     require_torch_gpu,
     slow,
     torch_device,
 )
-from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
+from transformers.utils import is_accelerate_available, is_optimum_quanto_available, is_torch_available
 
 
 if is_torch_available():
@@ -36,8 +36,8 @@ if is_torch_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-if is_quanto_available():
-    from quanto import QLayerNorm, QLinear
+if is_optimum_quanto_available():
+    from optimum.quanto import QLayerNorm, QLinear
 
     from transformers.integrations.quanto import replace_with_quanto_layers
 
@@ -47,7 +47,7 @@ class QuantoConfigTest(unittest.TestCase):
     pass
 
 
-@require_quanto
+@require_optimum_quanto
 @require_accelerate
 class QuantoTestIntegration(unittest.TestCase):
     model_id = "facebook/opt-350m"
@@ -124,7 +124,7 @@ class QuantoTestIntegration(unittest.TestCase):
 
 @slow
 @require_torch_gpu
-@require_quanto
+@require_optimum_quanto
 @require_accelerate
 class QuantoQuantizationTest(unittest.TestCase):
     """
@@ -187,7 +187,7 @@ class QuantoQuantizationTest(unittest.TestCase):
         self.check_inference_correctness(self.quantized_model, "cuda")
 
     def test_quantized_model_layers(self):
-        from quanto import QBitsTensor, QModuleMixin, QTensor
+        from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
 
         """
         Suite of simple test to check if the layers are quantized and are working properly
@@ -256,7 +256,7 @@ class QuantoQuantizationTest(unittest.TestCase):
             self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
 
     def test_compare_with_quanto(self):
-        from quanto import freeze, qint4, qint8, quantize
+        from optimum.quanto import freeze, qint4, qint8, quantize
 
         w_mapping = {"int8": qint8, "int4": qint4}
         model = AutoModelForCausalLM.from_pretrained(
@@ -272,7 +272,7 @@ class QuantoQuantizationTest(unittest.TestCase):
 
     @unittest.skip
     def test_load_from_quanto_saved(self):
-        from quanto import freeze, qint4, qint8, quantize
+        from optimum.quanto import freeze, qint4, qint8, quantize
 
         from transformers import QuantoConfig
 
@@ -356,7 +356,7 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
         """
         We check that we have unquantized value in the cpu and in the disk
         """
-        import quanto
+        from optimum.quanto import QBitsTensor, QTensor
 
         cpu_weights = self.quantized_model.transformer.h[22].self_attention.query_key_value._hf_hook.weights_map[
             "weight"
@@ -364,13 +364,11 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
         disk_weights = self.quantized_model.transformer.h[23].self_attention.query_key_value._hf_hook.weights_map[
             "weight"
         ]
-        self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QTensor))
-        self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QTensor))
+        self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, QTensor))
+        self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QTensor))
         if self.weights == "int4":
-            self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor))
-            self.assertTrue(
-                isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor)
-            )
+            self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor))
+            self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor))
 
 
 @unittest.skip(reason="Skipping test class because serialization is not supported yet")
@@ -416,18 +414,18 @@ class QuantoQuantizationSerializationCudaTest(QuantoQuantizationTest):
 
 
 class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest):
-    EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
+    EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
     weights = "int4"
 
 
 class QuantoQuantizationQBitsTensorOffloadTest(QuantoQuantizationOffloadTest):
-    EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
+    EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
     weights = "int4"
 
 
 @unittest.skip(reason="Skipping test class because serialization is not supported yet")
 class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializationTest):
-    EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
+    EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
     weights = "int4"
@@ -443,14 +441,14 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
         self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
 
 
-@require_quanto
+@require_optimum_quanto
 @require_torch_gpu
 class QuantoKVCacheQuantizationTest(unittest.TestCase):
     @slow
     @require_read_token
     def test_quantized_cache(self):
         EXPECTED_TEXT_COMPLETION = [
-            "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
+            "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory is the most",
             "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
         ]