Added Support for Custom Quantization (#35915)

* Added Support for Custom Quantization

* Update code

* code reformatted

* Updated Changes

* Updated Changes

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Parteek 2025-02-18 20:44:19 +05:30 committed by GitHub
parent 07182b2e10
commit 8eaae6bee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 116 additions and 3 deletions

View File

@ -0,0 +1,78 @@
import json
from typing import Any, Dict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin
@register_quantization_config("custom")
class CustomConfig(QuantizationConfigMixin):
def __init__(self):
self.quant_method = "custom"
self.bits = 8
def to_dict(self) -> Dict[str, Any]:
output = {
"num_bits": self.bits,
}
return output
def __repr__(self):
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
def to_diff_dict(self) -> Dict[str, Any]:
config_dict = self.to_dict()
default_config_dict = CustomConfig().to_dict()
serializable_config_dict = {}
for key, value in config_dict.items():
if value != default_config_dict[key]:
serializable_config_dict[key] = value
return serializable_config_dict
@register_quantizer("custom")
class CustomQuantizer(HfQuantizer):
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
self.scale_map = {}
self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
self.torch_dtype = kwargs.get("torch_dtype", torch.float32)
def _process_model_before_weight_loading(self, model, **kwargs):
return True
def _process_model_after_weight_loading(self, model, **kwargs):
return True
def is_serializable(self) -> bool:
return True
def is_trainable(self) -> bool:
return False
model_8bit = AutoModelForCausalLM.from_pretrained(
"facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
input_text = "once there is"
inputs = tokenizer(input_text, return_tensors="pt")
output = model_8bit.generate(
**inputs,
max_length=100,
num_return_sequences=1,
no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

View File

@ -3706,8 +3706,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map = hf_quantizer.update_device_map(device_map)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
else:
user_agent["quant"] = hf_quantizer.quantization_config.quant_method
# Force-set to `True` for more mem efficiency
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True

View File

@ -11,5 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .auto import AutoHfQuantizer, AutoQuantizationConfig
from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
from .base import HfQuantizer

View File

@ -35,6 +35,7 @@ from ..utils.quantization_config import (
TorchAoConfig,
VptqConfig,
)
from .base import HfQuantizer
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bitnet import BitNetHfQuantizer
@ -226,3 +227,35 @@ class AutoHfQuantizer:
)
return False
return True
def register_quantization_config(method: str):
"""Register a custom quantization configuration."""
def register_config_fn(cls):
if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
raise ValueError(f"Config '{method}' already registered")
if not issubclass(cls, QuantizationConfigMixin):
raise ValueError("Config must extend QuantizationConfigMixin")
AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
return cls
return register_config_fn
def register_quantizer(name: str):
"""Register a custom quantizer."""
def register_quantizer_fn(cls):
if name in AUTO_QUANTIZER_MAPPING:
raise ValueError(f"Quantizer '{name}' already registered")
if not issubclass(cls, HfQuantizer):
raise ValueError("Quantizer must extend HfQuantizer")
AUTO_QUANTIZER_MAPPING[name] = cls
return cls
return register_quantizer_fn