Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 13:20:12 +06:00
Added Support for Custom Quantization (#35915)
* Added Support for Custom Quantization
* Update code
* code reformatted
* Updated Changes
* Updated Changes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
parent 07182b2e10
commit 8eaae6bee9

examples/quantization/custom_quantization.py (new file, 78 lines added)
@@ -0,0 +1,78 @@
import json
from typing import Any, Dict

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


# Register the config under the "custom" method name so it can be resolved by name.
@register_quantization_config("custom")
class CustomConfig(QuantizationConfigMixin):
    def __init__(self):
        self.quant_method = "custom"
        self.bits = 8

    def to_dict(self) -> Dict[str, Any]:
        output = {
            "num_bits": self.bits,
        }
        return output

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> Dict[str, Any]:
        config_dict = self.to_dict()

        default_config_dict = CustomConfig().to_dict()

        serializable_config_dict = {}

        # Only keep keys that differ from the default config.
        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict


# Register the quantizer that handles quant_method == "custom".
@register_quantizer("custom")
class CustomQuantizer(HfQuantizer):
    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        self.scale_map = {}
        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.torch_dtype = kwargs.get("torch_dtype", torch.float32)

    # No-op hooks in this minimal example; a real quantizer would modify the model here.
    def _process_model_before_weight_loading(self, model, **kwargs):
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        return True

    def is_serializable(self) -> bool:
        return True

    def is_trainable(self) -> bool:
        return False


# Load the model with the custom quantization config.
model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
input_text = "once there is"
inputs = tokenizer(input_text, return_tensors="pt")
output = model_8bit.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
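Note that `CustomQuantizer` above leaves both weight-loading hooks as no-ops, so the model is loaded without any actual weight modification. Below is a minimal sketch of what a real implementation could do after the weights are loaded, assuming a simple per-tensor int8 scheme; the `quantize_linear_weights_` helper and its rounding logic are illustrative only and not part of this PR.

import torch
import torch.nn as nn


def quantize_linear_weights_(model: nn.Module, scale_map: dict) -> None:
    """Illustrative per-tensor int8 quantization of every nn.Linear weight.

    Stores per-layer scales in `scale_map`; a complete quantizer would also
    swap each Linear for a module that keeps the int8 weights and
    dequantizes on the fly.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            scale = weight.abs().max().clamp(min=1e-8) / 127.0
            q_weight = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
            scale_map[name] = scale
            # Write back a dequantized copy so the model stays runnable in this sketch.
            module.weight.data = q_weight.to(weight.dtype) * scale

Inside the example, this could be called from `_process_model_after_weight_loading` as `quantize_linear_weights_(model, self.scale_map)` before returning.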
src/transformers/modeling_utils.py
@@ -3706,8 +3706,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
            device_map = hf_quantizer.update_device_map(device_map)

            # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
            if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
                user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
            else:
                user_agent["quant"] = hf_quantizer.quantization_config.quant_method
            # Force-set to `True` for more mem efficiency
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
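The `hasattr(..., "value")` branch is needed because built-in quantization configs store `quant_method` as an enum member (whose `.value` is the string reported in telemetry), while a user-registered config such as `CustomConfig` above stores a plain string. A self-contained sketch of the distinction, where `QuantMethodExample` and the helper are illustrative stand-ins rather than transformers code:

from enum import Enum


class QuantMethodExample(str, Enum):
    # Stand-in for the built-in quantization-method enum.
    BITS_AND_BYTES = "bitsandbytes"


def quant_method_for_telemetry(quant_method) -> str:
    # Mirrors the hasattr check above: enum members expose .value, plain strings do not.
    return quant_method.value if hasattr(quant_method, "value") else quant_method


assert quant_method_for_telemetry(QuantMethodExample.BITS_AND_BYTES) == "bitsandbytes"
assert quant_method_for_telemetry("custom") == "custom"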
src/transformers/quantizers/__init__.py
@@ -11,5 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .auto import AutoHfQuantizer, AutoQuantizationConfig
+from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
 from .base import HfQuantizer
src/transformers/quantizers/auto.py
@@ -35,6 +35,7 @@ from ..utils.quantization_config import (
    TorchAoConfig,
    VptqConfig,
)
from .base import HfQuantizer
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bitnet import BitNetHfQuantizer
@@ -226,3 +227,35 @@ class AutoHfQuantizer:
            )
            return False
        return True


def register_quantization_config(method: str):
    """Register a custom quantization configuration."""

    def register_config_fn(cls):
        if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
            raise ValueError(f"Config '{method}' already registered")

        if not issubclass(cls, QuantizationConfigMixin):
            raise ValueError("Config must extend QuantizationConfigMixin")

        AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
        return cls

    return register_config_fn


def register_quantizer(name: str):
    """Register a custom quantizer."""

    def register_quantizer_fn(cls):
        if name in AUTO_QUANTIZER_MAPPING:
            raise ValueError(f"Quantizer '{name}' already registered")

        if not issubclass(cls, HfQuantizer):
            raise ValueError("Quantizer must extend HfQuantizer")

        AUTO_QUANTIZER_MAPPING[name] = cls
        return cls

    return register_quantizer_fn
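With these two decorators exported, wiring a new method into the auto mapping is a one-liner per class, and reusing a taken name fails fast via the `ValueError` guards above. A short usage sketch mirroring the example file earlier in this commit; the "my-8bit" name and the no-op classes are made up for illustration:

from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


@register_quantization_config("my-8bit")
class My8BitConfig(QuantizationConfigMixin):
    def __init__(self):
        self.quant_method = "my-8bit"


@register_quantizer("my-8bit")
class My8BitQuantizer(HfQuantizer):
    def _process_model_before_weight_loading(self, model, **kwargs):
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        return True

    def is_serializable(self) -> bool:
        return True

    def is_trainable(self) -> bool:
        return False


# A second registration under the same name hits the guard in register_quantization_config.
try:
    @register_quantization_config("my-8bit")
    class DuplicateConfig(QuantizationConfigMixin):
        pass
except ValueError as err:
    print(err)  # Config 'my-8bit' already registered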