Break weight tying when quantizing input embedding (#37905)

Summary:
Currently, when we quantize the input embedding for some models, the output embedding
(lm_head) gets quantized the same way because the two weights are tied, and this may not be
what we want. To break the tie, this change adds an option that switches the order of
operations to
1. load the unquantized weights
2. tie the weights
3. quantize

so that the tie is broken before the embedding weight is quantized and lm_head keeps its
unquantized weight.
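
For intuition, here is a minimal, self-contained PyTorch sketch of why the order matters. The "quantization" below is only a stand-in (a parameter swap), not torchao, and the module names are illustrative:

```
import torch
import torch.nn as nn

# toy model: an input embedding tied to an output projection (lm_head)
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight          # weight tying: one shared Parameter
assert lm_head.weight is embed.weight

# stand-in for quantization: the quantizer swaps the embedding's Parameter
# for a new (quantized) one; here we just fake it with a rounded copy
embed.weight = nn.Parameter(embed.weight.detach().round())

# the swap only touches `embed`, so lm_head still holds the original float
# weight -- as long as nothing re-ties them afterwards. That is why the
# change also sets tie_word_embeddings=False: otherwise the usual re-tying
# at the end of loading would point lm_head back at the quantized weight.
assert lm_head.weight is not embed.weight
```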

Test Plan:
```
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TorchAoConfig,
)
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    AOPerModuleConfig,
)
from torchao.quantization.granularity import PerGroup, PerAxis
import torch

model_id = "microsoft/Phi-4-mini-instruct"

# int8 weight-only, per-channel (axis 0) quantization for the input embedding
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)
# int8 dynamic activation + int4 weight (group size 32) quantization for the linear layers
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_scale_dtype=torch.bfloat16,
)
# apply linear_config everywhere by default, and embedding_config to the input embedding
quant_config = AOPerModuleConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, device_map="auto", quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(quantized_model)
print("embed_tokens.weight:", quantized_model.model.embed_tokens.weight)
print("lm head weight:", quantized_model.lm_head.weight)
from transformers.modeling_utils import find_tied_parameters
print(find_tied_parameters(quantized_model))
```
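
If the untying works as intended, the last print above should report no tied parameter groups. An equivalent quick check (illustrative only, not part of the original test plan, and assuming the snippet above has been run):

```
# no parameters should be reported as tied anymore
assert not find_tied_parameters(quantized_model)
# and lm_head keeps its own weight tensor instead of the quantized embedding weight
assert quantized_model.lm_head.weight is not quantized_model.model.embed_tokens.weight
```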

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Jerry Zhang 2025-05-02 01:53:23 -07:00 committed by GitHub
parent 8a0a508f2b
commit fa3c3f9cab
2 changed files with 16 additions and 0 deletions

```diff
@@ -247,6 +247,16 @@ class TorchAoHfQuantizer(HfQuantizer):
             module._parameters[tensor_name] = torch.nn.Parameter(
                 param_value, requires_grad=param_value.requires_grad
             ).to(device=target_device)
+            # if we are quantizing tied parameters, to avoid tying the quantized weights
+            # the correct order to do it is
+            # 1. load the weight to model
+            # 2. run tie_weights to populate the weights
+            # 3. quantize
+            input_embed = model.get_input_embeddings()
+            if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
+                model.tie_weights()
+                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
+
             # handle AOPerModuleConfig, introduced in torchao 0.11.0+
             if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
                 from torchao.quantization import AOPerModuleConfig
```

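Since the hunk above flips `tie_word_embeddings` on the text config, one quick way to confirm the untie took effect after loading (assuming the `quantized_model` from the test plan above) is to read that flag back:

```
# expected to be False once the input embedding has been quantized
# with untie_embedding_weights=True
text_config = quantized_model.config.get_text_config(decoder=True)
print(text_config.tie_word_embeddings)
```
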
```diff
@@ -1555,6 +1555,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     modules_to_not_convert: Optional[List]
     quant_type_kwargs: Dict[str, Any]
     include_embedding: bool
+    untie_embedding_weights: bool
     """This is a config class for torchao quantization/sparsity techniques.
@@ -1569,6 +1570,9 @@
         inlcude_embedding (`bool`, default to `False`):
             Whether to include embedding in quantization or not, input embedding will be removed from
             the module_not_to_convert list as well if this flag is set.
+        untie_embedding_weights (`bool`, default to `False`):
+            Whether to untie the weights when we are quantizing input embedding weights that is tied
+            to other weights.
         kwargs (`Dict[str, Any]`, *optional*):
             The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
             `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
@@ -1614,6 +1618,7 @@
         quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
         modules_to_not_convert: Optional[List] = None,
         include_embedding: bool = False,
+        untie_embedding_weights: bool = False,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.TORCHAO
@@ -1621,6 +1626,7 @@
         self.modules_to_not_convert = modules_to_not_convert
         self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
         self.include_embedding = include_embedding
+        self.untie_embedding_weights = untie_embedding_weights
         self.post_init()

     @staticmethod
```
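
For completeness, the new keyword argument also works with the simple string quant types. A minimal, illustrative construction (assuming torchao is installed), just to show the flag being stored on the config:

```
from transformers import TorchAoConfig

config = TorchAoConfig("int8_weight_only", include_embedding=True, untie_embedding_weights=True)
print(config.untie_embedding_weights)  # True
```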