From 8eaae6bee973e2068b5ed5622a708d99ce3dd862 Mon Sep 17 00:00:00 2001
From: Parteek
Date: Tue, 18 Feb 2025 20:44:19 +0530
Subject: [PATCH] Added Support for Custom Quantization (#35915)

* Added Support for Custom Quantization

* Update code

* code reformatted

* Updated Changes

* Updated Changes

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
---
 examples/quantization/custom_quantization.py | 78 ++++++++++++++++++++
 src/transformers/modeling_utils.py           |  6 +-
 src/transformers/quantizers/__init__.py      |  2 +-
 src/transformers/quantizers/auto.py          | 33 +++++++++
 4 files changed, 116 insertions(+), 3 deletions(-)
 create mode 100644 examples/quantization/custom_quantization.py

diff --git a/examples/quantization/custom_quantization.py b/examples/quantization/custom_quantization.py
new file mode 100644
index 00000000000..16b31cd8ebe
--- /dev/null
+++ b/examples/quantization/custom_quantization.py
@@ -0,0 +1,78 @@
+import json
+from typing import Any, Dict
+
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
+from transformers.utils.quantization_config import QuantizationConfigMixin
+
+
+@register_quantization_config("custom")
+class CustomConfig(QuantizationConfigMixin):
+    def __init__(self):
+        self.quant_method = "custom"
+        self.bits = 8
+
+    def to_dict(self) -> Dict[str, Any]:
+        output = {
+            "num_bits": self.bits,
+        }
+        return output
+
+    def __repr__(self):
+        config_dict = self.to_dict()
+        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+
+    def to_diff_dict(self) -> Dict[str, Any]:
+        config_dict = self.to_dict()
+
+        default_config_dict = CustomConfig().to_dict()
+
+        serializable_config_dict = {}
+
+        for key, value in config_dict.items():
+            if value != default_config_dict[key]:
+                serializable_config_dict[key] = value
+
+        return serializable_config_dict
+
+
+@register_quantizer("custom")
+class CustomQuantizer(HfQuantizer):
+    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+        self.quantization_config = quantization_config
+        self.scale_map = {}
+        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
+        self.torch_dtype = kwargs.get("torch_dtype", torch.float32)
+
+    def _process_model_before_weight_loading(self, model, **kwargs):
+        return True
+
+    def _process_model_after_weight_loading(self, model, **kwargs):
+        return True
+
+    def is_serializable(self) -> bool:
+        return True
+
+    def is_trainable(self) -> bool:
+        return False
+
+
+model_8bit = AutoModelForCausalLM.from_pretrained(
+    "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
+)
+
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+input_text = "once there is"
+inputs = tokenizer(input_text, return_tensors="pt")
+output = model_8bit.generate(
+    **inputs,
+    max_length=100,
+    num_return_sequences=1,
+    no_repeat_ngram_size=2,
+)
+generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+print(generated_text)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b75151992c5..5bc952e7bd8 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3706,8 +3706,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             device_map = hf_quantizer.update_device_map(device_map)
 
             # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
-            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
-
+            if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
+                user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+            else:
+                user_agent["quant"] = hf_quantizer.quantization_config.quant_method
             # Force-set to `True` for more mem efficiency
             if low_cpu_mem_usage is None:
                 low_cpu_mem_usage = True
diff --git a/src/transformers/quantizers/__init__.py b/src/transformers/quantizers/__init__.py
index 3409af4cd78..96c8d4fa504 100755
--- a/src/transformers/quantizers/__init__.py
+++ b/src/transformers/quantizers/__init__.py
@@ -11,5 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .auto import AutoHfQuantizer, AutoQuantizationConfig
+from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
 from .base import HfQuantizer
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index ee7c832b1de..64634f98a44 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -35,6 +35,7 @@ from ..utils.quantization_config import (
     TorchAoConfig,
     VptqConfig,
 )
+from .base import HfQuantizer
 from .quantizer_aqlm import AqlmHfQuantizer
 from .quantizer_awq import AwqQuantizer
 from .quantizer_bitnet import BitNetHfQuantizer
@@ -226,3 +227,35 @@ class AutoHfQuantizer:
             )
             return False
         return True
+
+
+def register_quantization_config(method: str):
+    """Register a custom quantization configuration."""
+
+    def register_config_fn(cls):
+        if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
+            raise ValueError(f"Config '{method}' already registered")
+
+        if not issubclass(cls, QuantizationConfigMixin):
+            raise ValueError("Config must extend QuantizationConfigMixin")
+
+        AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
+        return cls
+
+    return register_config_fn
+
+
+def register_quantizer(name: str):
+    """Register a custom quantizer."""
+
+    def register_quantizer_fn(cls):
+        if name in AUTO_QUANTIZER_MAPPING:
+            raise ValueError(f"Quantizer '{name}' already registered")
+
+        if not issubclass(cls, HfQuantizer):
+            raise ValueError("Quantizer must extend HfQuantizer")
+
+        AUTO_QUANTIZER_MAPPING[name] = cls
+        return cls
+
+    return register_quantizer_fn
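
Note on the example file: the `_process_model_before_weight_loading` and `_process_model_after_weight_loading` hooks in `CustomQuantizer` are stubs that return `True` without touching the model, so the example exercises the registration plumbing but performs no real quantization. The sketch below shows one way such a hook could do actual work. It is illustrative only and not part of this patch: `Int8DynamicLinear` and `replace_linears` are hypothetical names, and the symmetric per-tensor int8 scheme is an assumption chosen for brevity.

import torch
import torch.nn as nn


class Int8DynamicLinear(nn.Module):
    """Hypothetical drop-in for nn.Linear: int8 weights plus a per-tensor scale (not part of the PR)."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        weight = linear.weight.data
        # Symmetric per-tensor quantization: map the largest |w| to 127.
        scale = weight.abs().max().clamp(min=1e-8) / 127.0
        self.register_buffer("int8_weight", torch.round(weight / scale).clamp(-127, 127).to(torch.int8))
        self.register_buffer("scale", scale)
        self.bias = linear.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly; a production kernel would fuse this into the matmul.
        weight = (self.int8_weight.to(torch.float32) * self.scale).to(x.dtype)
        return nn.functional.linear(x, weight, self.bias)


def replace_linears(module: nn.Module) -> None:
    """Recursively swap every nn.Linear submodule for Int8DynamicLinear."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, Int8DynamicLinear(child))
        else:
            replace_linears(child)

A quantizer built on this sketch would call `replace_linears(model)` from `_process_model_after_weight_loading`, once the float weights have been loaded, and could record each layer's scale in the `scale_map` dict that `CustomQuantizer.__init__` already creates.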