mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Added Support for Custom Quantization (#35915)
* Added Support for Custom Quantization * Update code * code reformatted * Updated Changes * Updated Changes --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
parent
07182b2e10
commit
8eaae6bee9
78
examples/quantization/custom_quantization.py
Normal file
78
examples/quantization/custom_quantization.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
|
||||||
|
from transformers.utils.quantization_config import QuantizationConfigMixin
|
||||||
|
|
||||||
|
|
||||||
|
@register_quantization_config("custom")
|
||||||
|
class CustomConfig(QuantizationConfigMixin):
|
||||||
|
def __init__(self):
|
||||||
|
self.quant_method = "custom"
|
||||||
|
self.bits = 8
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
output = {
|
||||||
|
"num_bits": self.bits,
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
config_dict = self.to_dict()
|
||||||
|
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||||
|
|
||||||
|
def to_diff_dict(self) -> Dict[str, Any]:
|
||||||
|
config_dict = self.to_dict()
|
||||||
|
|
||||||
|
default_config_dict = CustomConfig().to_dict()
|
||||||
|
|
||||||
|
serializable_config_dict = {}
|
||||||
|
|
||||||
|
for key, value in config_dict.items():
|
||||||
|
if value != default_config_dict[key]:
|
||||||
|
serializable_config_dict[key] = value
|
||||||
|
|
||||||
|
return serializable_config_dict
|
||||||
|
|
||||||
|
|
||||||
|
@register_quantizer("custom")
|
||||||
|
class CustomQuantizer(HfQuantizer):
|
||||||
|
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
|
||||||
|
super().__init__(quantization_config, **kwargs)
|
||||||
|
self.quantization_config = quantization_config
|
||||||
|
self.scale_map = {}
|
||||||
|
self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.torch_dtype = kwargs.get("torch_dtype", torch.float32)
|
||||||
|
|
||||||
|
def _process_model_before_weight_loading(self, model, **kwargs):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _process_model_after_weight_loading(self, model, **kwargs):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_trainable(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
model_8bit = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
|
input_text = "once there is"
|
||||||
|
inputs = tokenizer(input_text, return_tensors="pt")
|
||||||
|
output = model_8bit.generate(
|
||||||
|
**inputs,
|
||||||
|
max_length=100,
|
||||||
|
num_return_sequences=1,
|
||||||
|
no_repeat_ngram_size=2,
|
||||||
|
)
|
||||||
|
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
print(generated_text)
|
@ -3706,8 +3706,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
device_map = hf_quantizer.update_device_map(device_map)
|
device_map = hf_quantizer.update_device_map(device_map)
|
||||||
|
|
||||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||||
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
|
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
|
||||||
|
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
|
||||||
|
else:
|
||||||
|
user_agent["quant"] = hf_quantizer.quantization_config.quant_method
|
||||||
# Force-set to `True` for more mem efficiency
|
# Force-set to `True` for more mem efficiency
|
||||||
if low_cpu_mem_usage is None:
|
if low_cpu_mem_usage is None:
|
||||||
low_cpu_mem_usage = True
|
low_cpu_mem_usage = True
|
||||||
|
@ -11,5 +11,5 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from .auto import AutoHfQuantizer, AutoQuantizationConfig
|
from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
|
||||||
from .base import HfQuantizer
|
from .base import HfQuantizer
|
||||||
|
@ -35,6 +35,7 @@ from ..utils.quantization_config import (
|
|||||||
TorchAoConfig,
|
TorchAoConfig,
|
||||||
VptqConfig,
|
VptqConfig,
|
||||||
)
|
)
|
||||||
|
from .base import HfQuantizer
|
||||||
from .quantizer_aqlm import AqlmHfQuantizer
|
from .quantizer_aqlm import AqlmHfQuantizer
|
||||||
from .quantizer_awq import AwqQuantizer
|
from .quantizer_awq import AwqQuantizer
|
||||||
from .quantizer_bitnet import BitNetHfQuantizer
|
from .quantizer_bitnet import BitNetHfQuantizer
|
||||||
@ -226,3 +227,35 @@ class AutoHfQuantizer:
|
|||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def register_quantization_config(method: str):
|
||||||
|
"""Register a custom quantization configuration."""
|
||||||
|
|
||||||
|
def register_config_fn(cls):
|
||||||
|
if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
|
||||||
|
raise ValueError(f"Config '{method}' already registered")
|
||||||
|
|
||||||
|
if not issubclass(cls, QuantizationConfigMixin):
|
||||||
|
raise ValueError("Config must extend QuantizationConfigMixin")
|
||||||
|
|
||||||
|
AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return register_config_fn
|
||||||
|
|
||||||
|
|
||||||
|
def register_quantizer(name: str):
|
||||||
|
"""Register a custom quantizer."""
|
||||||
|
|
||||||
|
def register_quantizer_fn(cls):
|
||||||
|
if name in AUTO_QUANTIZER_MAPPING:
|
||||||
|
raise ValueError(f"Quantizer '{name}' already registered")
|
||||||
|
|
||||||
|
if not issubclass(cls, HfQuantizer):
|
||||||
|
raise ValueError("Quantizer must extend HfQuantizer")
|
||||||
|
|
||||||
|
AUTO_QUANTIZER_MAPPING[name] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return register_quantizer_fn
|
||||||
|
Loading…
Reference in New Issue
Block a user