diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 0bd88725855..030b0de0c0c 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -185,10 +185,14 @@ class TorchAoHfQuantizer(HfQuantizer):
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
-        if self.quantization_config.include_embedding:
+        if self.quantization_config.include_input_output_embeddings:
             input_emb = model.get_input_embeddings()
             input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
-            self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in input_emb_names]
+            output_emb = model.get_output_embeddings()
+            output_emb_names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
+            self.modules_to_not_convert = [
+                x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
+            ]
         return

     def check_quantized_param(
@@ -213,7 +217,7 @@ class TorchAoHfQuantizer(HfQuantizer):
         # we only quantize the weight of nn.Linear and nn.Embedding
         module, tensor_name = get_module_from_name(model, param_name)
         _QUANTIZABLE = [torch.nn.Linear]
-        if self.quantization_config.include_embedding:
+        if self.quantization_config.include_input_output_embeddings:
             _QUANTIZABLE.append(torch.nn.Embedding)
         return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")

diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 5d9ab1f6f20..ec2a6c76dee 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -1554,7 +1554,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     quant_type: Union[str, "AOBaseConfig"]  # noqa: F821
     modules_to_not_convert: Optional[List]
     quant_type_kwargs: Dict[str, Any]
-    include_embedding: bool
+    include_input_output_embeddings: bool
     untie_embedding_weights: bool

     """This is a config class for torchao quantization/sparsity techniques.
@@ -1617,7 +1617,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         self,
         quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
         modules_to_not_convert: Optional[List] = None,
-        include_embedding: bool = False,
+        include_input_output_embeddings: bool = False,
         untie_embedding_weights: bool = False,
         **kwargs,
     ):
@@ -1625,7 +1625,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
         self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
-        self.include_embedding = include_embedding
+        self.include_input_output_embeddings = include_input_output_embeddings
         self.untie_embedding_weights = untie_embedding_weights
         self.post_init()

diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py
index 61d569a040b..8f1c15c94d6 100644
--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -201,7 +201,7 @@ class TorchAoTest(unittest.TestCase):
         self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

     @require_torchao_version_greater_or_equal("0.11.0")
-    def test_include_embedding(self):
+    def test_include_input_output_embeddings(self):
         weight_dtype = torch.int8
         granularity = PerAxis(0)
         mapping_type = MappingType.ASYMMETRIC
@@ -210,9 +210,11 @@ class TorchAoTest(unittest.TestCase):
             granularity=granularity,
             mapping_type=mapping_type,
         )
-        config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
-        # need set `include_embedding` to True
-        quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
+        config = AOPerModuleConfig(
+            {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
+        )
+        # need set `include_input_output_embeddings` to True
+        quant_config = TorchAoConfig(quant_type=config, include_input_output_embeddings=True)
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             device_map=self.device,
@@ -220,6 +222,7 @@ class TorchAoTest(unittest.TestCase):
         )
         # making sure embedding is quantized
         self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.lm_head.weight, AffineQuantizedTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)

         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
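For reference, a minimal usage sketch of the renamed flag, mirroring the updated test. Only `TorchAoConfig(..., include_input_output_embeddings=True)` and the per-module config shape come from this patch; the checkpoint name, `device_map`, module names such as `model.embed_tokens`/`lm_head` (Llama-style architecture), and the exact torchao import paths are assumptions and may need adjusting for your setup.

```python
import torch
from torchao.quantization import AOPerModuleConfig, IntxWeightOnlyConfig  # paths may vary across torchao versions
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_primitives import MappingType

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# int8 weight-only quantization config, reused for both embedding modules (same as the test)
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
    mapping_type=MappingType.ASYMMETRIC,
)

# Quantize only the input embedding and lm_head; "_default": None leaves every other module untouched.
config = AOPerModuleConfig(
    {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
)

# The renamed flag keeps both the input and output embeddings out of `modules_to_not_convert`.
quant_config = TorchAoConfig(quant_type=config, include_input_output_embeddings=True)

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # hypothetical Llama-style checkpoint; swap in your own
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Both weights should now be AffineQuantizedTensor instances, matching the test's assertions.
print(type(model.model.embed_tokens.weight))
print(type(model.lm_head.weight))
```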