"""Minimal example of registering a custom quantization scheme with transformers.

Registers a no-op "custom" quantization config + quantizer pair, then loads
facebook/opt-350m through it and generates a short text continuation.
"""

import json
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


@register_quantization_config("custom")
class CustomConfig(QuantizationConfigMixin):
    """Configuration for the "custom" quantization method (fixed 8-bit)."""

    def __init__(self):
        self.quant_method = "custom"
        self.bits = 8

    def to_dict(self) -> dict[str, Any]:
        """Serialize the config; only the bit width is persisted."""
        return {"num_bits": self.bits}

    def __repr__(self) -> str:
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> dict[str, Any]:
        """Return only the entries that differ from a freshly-built default config."""
        config_dict = self.to_dict()
        default_config_dict = CustomConfig().to_dict()
        return {
            key: value
            for key, value in config_dict.items()
            if value != default_config_dict[key]
        }


@register_quantizer("custom")
class CustomQuantizer(HfQuantizer):
    """Demo quantizer for the "custom" method.

    Both weight-loading hooks are no-ops, so the model loads unmodified;
    this class only demonstrates the registration plumbing.
    """

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        self.scale_map = {}  # placeholder for per-tensor scales a real quantizer would fill
        # NOTE(review): single-device assumption — confirm behavior for multi-GPU setups.
        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.torch_dtype = kwargs.get("torch_dtype", torch.float32)

    def _process_model_before_weight_loading(self, model, **kwargs):
        # Hook invoked before weights are loaded; nothing to transform here.
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        # Hook invoked after weights are loaded; nothing to transform here.
        return True

    def is_serializable(self) -> bool:
        return True

    def is_trainable(self) -> bool:
        return False


def main() -> None:
    """Load OPT-350m through the custom quantizer and print a generated continuation."""
    model_8bit = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",
        quantization_config=CustomConfig(),
        torch_dtype="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

    input_text = "once there is"
    inputs = tokenizer(input_text, return_tensors="pt")
    output = model_8bit.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        no_repeat_ngram_size=2,  # forbid repeating any 2-gram during decoding
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(generated_text)


# Guarded so that importing this module does not trigger a model download/generation.
if __name__ == "__main__":
    main()