Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Break weight tying when quantizing input embedding (#37905)
Summary:

Currently, when we try to quantize the input embedding for some models, the output embedding (lm_head) is quantized the same way because the two are tied, and this may not be what we want. To break the tie, we added an option that makes the quantizer

1. load the unquantized weight,
2. tie the weights, and
3. quantize,

so that the tie is broken before the embedding weight is quantized.

Test Plan:

```
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    TorchAoConfig,
)
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    AOPerModuleConfig,
)
from torchao.quantization.granularity import PerGroup, PerAxis
import torch

model_id = "microsoft/Phi-4-mini-instruct"

embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_scale_dtype=torch.bfloat16,
)
quant_config = AOPerModuleConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
quantization_config = TorchAoConfig(
    quant_type=quant_config, include_embedding=True, untie_embedding_weights=True
)

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, device_map="auto", quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(quantized_model)
print("embed_tokens.weight:", quantized_model.model.embed_tokens.weight)
print("lm head weight:", quantized_model.lm_head.weight)

from transformers.modeling_utils import find_tied_parameters

print(find_tied_parameters(quantized_model))
```

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
parent 8a0a508f2b
commit fa3c3f9cab
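The Test Plan above ends by printing find_tied_parameters(quantized_model). Based on the summary (this is an expectation, not captured output from the commit), breaking the tie should leave no tied parameter groups:

```
from transformers.modeling_utils import find_tied_parameters

# Expectation based on the summary, reusing `quantized_model` from the Test
# Plan above: with untie_embedding_weights=True the embedding and lm_head no
# longer share a parameter, so no tied groups should be reported.
print(find_tied_parameters(quantized_model))  # expected: []
```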
```
@@ -247,6 +247,16 @@ class TorchAoHfQuantizer(HfQuantizer):
             module._parameters[tensor_name] = torch.nn.Parameter(
                 param_value, requires_grad=param_value.requires_grad
             ).to(device=target_device)
+            # if we are quantizing tied parameters, to avoid tying the quantized weights
+            # the correct order to do it is
+            # 1. load the weight to model
+            # 2. run tie_weights to populate the weights
+            # 3. quantize
+            input_embed = model.get_input_embeddings()
+            if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
+                model.tie_weights()
+                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
+
             # handle AOPerModuleConfig, introduced in torchao 0.11.0+
             if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
                 from torchao.quantization import AOPerModuleConfig
```
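The added branch above is what breaks the tie before quantization happens. For intuition, here is a minimal standalone sketch of the same idea in plain PyTorch; it assumes `model` is an already loaded causal LM whose input and output embeddings currently share storage, and it is an illustration rather than the quantizer's exact code path:

```
import torch

# Illustration only (assumes `model` is a loaded causal LM with tied embeddings):
# give lm_head its own copy of the still-unquantized weight and mark the config
# as untied, so quantizing the input embedding later leaves lm_head untouched.
embed = model.get_input_embeddings()      # nn.Embedding
lm_head = model.get_output_embeddings()   # nn.Linear sharing embed.weight

if lm_head is not None and lm_head.weight.data_ptr() == embed.weight.data_ptr():
    lm_head.weight = torch.nn.Parameter(embed.weight.detach().clone())
    model.config.get_text_config(decoder=True).tie_word_embeddings = False
```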
```
@@ -1555,6 +1555,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     modules_to_not_convert: Optional[List]
     quant_type_kwargs: Dict[str, Any]
     include_embedding: bool
+    untie_embedding_weights: bool

     """This is a config class for torchao quantization/sparsity techniques.
```
```
@@ -1569,6 +1570,9 @@ class TorchAoConfig(QuantizationConfigMixin):
         inlcude_embedding (`bool`, default to `False`):
             Whether to include embedding in quantization or not, input embedding will be removed from
             the module_not_to_convert list as well if this flag is set.
+        untie_embedding_weights (`bool`, default to `False`):
+            Whether to untie the weights when we are quantizing input embedding weights that is tied
+            to other weights.
         kwargs (`Dict[str, Any]`, *optional*):
             The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
             `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
```
```
@@ -1614,6 +1618,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
         modules_to_not_convert: Optional[List] = None,
         include_embedding: bool = False,
+        untie_embedding_weights: bool = False,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.TORCHAO
```
```
@@ -1621,6 +1626,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         self.modules_to_not_convert = modules_to_not_convert
         self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
         self.include_embedding = include_embedding
+        self.untie_embedding_weights = untie_embedding_weights
         self.post_init()

     @staticmethod
```
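On the config side, the net effect of these hunks is one new constructor keyword and attribute. A minimal construction sketch follows; the string quant_type is illustrative (the commit's Test Plan passes an AOPerModuleConfig instead), and torchao must be installed for post_init validation:

```
from transformers import TorchAoConfig

# Illustrative values; any supported quant_type is handled the same way here.
qconfig = TorchAoConfig(
    quant_type="int8_weight_only",
    include_embedding=True,
    untie_embedding_weights=True,
)
print(qconfig.include_embedding, qconfig.untie_embedding_weights)  # True True
```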