Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Include output embedding as well with `include_embedding` flag (#37935)

* Include output embedding as well with `include_embedding` flag

  Summary: att

  Test Plan:
  python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* format

* rename `include_embedding` to `include_input_output_embeddings`

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Commit 44fa04ae8d (parent 34c1e29cdd)
```diff
@@ -185,10 +185,14 @@ class TorchAoHfQuantizer(HfQuantizer):
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
-        if self.quantization_config.include_embedding:
+        if self.quantization_config.include_input_output_embeddings:
             input_emb = model.get_input_embeddings()
             input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
-            self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in input_emb_names]
+            output_emb = model.get_output_embeddings()
+            output_emb_names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
+            self.modules_to_not_convert = [
+                x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
+            ]
         return

     def check_quantized_param(
```
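Worth noting: the `id()`-based scan collects every name a module is registered under, which matters when the input and output embeddings share the same tied object. A minimal, self-contained sketch of the same pattern; the `TinyLM` toy model is hypothetical, standing in for a real transformers model:

```python
import torch

class TinyLM(torch.nn.Module):
    """Toy causal-LM skeleton with tied input/output embeddings (illustration only)."""

    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab, dim)
        self.lm_head = torch.nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # weight tying

model = TinyLM()
input_emb, output_emb = model.embed_tokens, model.lm_head

# Same identity-based scan the quantizer uses: find every registered name for
# each embedding module so those names can be dropped from modules_to_not_convert.
input_emb_names = [n for n, m in model.named_modules() if id(m) == id(input_emb)]
output_emb_names = [n for n, m in model.named_modules() if id(m) == id(output_emb)]
print(input_emb_names, output_emb_names)  # ['embed_tokens'] ['lm_head']
```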
```diff
@@ -213,7 +217,7 @@ class TorchAoHfQuantizer(HfQuantizer):
         # we only quantize the weight of nn.Linear and nn.Embedding
         module, tensor_name = get_module_from_name(model, param_name)
         _QUANTIZABLE = [torch.nn.Linear]
-        if self.quantization_config.include_embedding:
+        if self.quantization_config.include_input_output_embeddings:
             _QUANTIZABLE.append(torch.nn.Embedding)
         return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
```
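The predicate change is small but load-bearing: without the flag, `nn.Embedding` weights never count as quantizable. A standalone restatement of the gating; the `is_quantizable` helper is hypothetical, written only to show the logic:

```python
import torch

def is_quantizable(module: torch.nn.Module, tensor_name: str, include_embeddings: bool) -> bool:
    # Mirrors check_quantized_param: Linear weights always qualify,
    # Embedding weights only when the flag is set.
    quantizable = [torch.nn.Linear]
    if include_embeddings:
        quantizable.append(torch.nn.Embedding)
    return isinstance(module, tuple(quantizable)) and tensor_name == "weight"

emb = torch.nn.Embedding(10, 4)
print(is_quantizable(emb, "weight", include_embeddings=False))  # False: skipped by default
print(is_quantizable(emb, "weight", include_embeddings=True))   # True: opted in
```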
```diff
@@ -1554,7 +1554,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     quant_type: Union[str, "AOBaseConfig"]  # noqa: F821
     modules_to_not_convert: Optional[List]
     quant_type_kwargs: Dict[str, Any]
-    include_embedding: bool
+    include_input_output_embeddings: bool
     untie_embedding_weights: bool

     """This is a config class for torchao quantization/sparsity techniques.
```
```diff
@@ -1617,7 +1617,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         self,
         quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
         modules_to_not_convert: Optional[List] = None,
-        include_embedding: bool = False,
+        include_input_output_embeddings: bool = False,
         untie_embedding_weights: bool = False,
         **kwargs,
     ):
```
```diff
@@ -1625,7 +1625,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
         self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
-        self.include_embedding = include_embedding
+        self.include_input_output_embeddings = include_input_output_embeddings
         self.untie_embedding_weights = untie_embedding_weights
         self.post_init()
```
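Taken together, callers opt in through the renamed flag on `TorchAoConfig`. A usage sketch mirroring the updated test; the checkpoint name is a placeholder and the torchao import paths are assumptions for torchao >= 0.11:

```python
import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import AOPerModuleConfig, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_primitives import MappingType

# Int8 weight-only scheme for the embedding modules, as in the test.
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
    mapping_type=MappingType.ASYMMETRIC,
)
# Per-module mapping: default modules untouched, both embedding ends quantized.
config = AOPerModuleConfig(
    {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
)
quant_config = TorchAoConfig(quant_type=config, include_input_output_embeddings=True)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder checkpoint; any causal LM should do
    device_map="auto",
    quantization_config=quant_config,
)
```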
```diff
@@ -201,7 +201,7 @@ class TorchAoTest(unittest.TestCase):
         self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

     @require_torchao_version_greater_or_equal("0.11.0")
-    def test_include_embedding(self):
+    def test_include_input_output_embeddings(self):
         weight_dtype = torch.int8
         granularity = PerAxis(0)
         mapping_type = MappingType.ASYMMETRIC
```
```diff
@@ -210,9 +210,11 @@ class TorchAoTest(unittest.TestCase):
             granularity=granularity,
             mapping_type=mapping_type,
         )
-        config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
-        # need set `include_embedding` to True
-        quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
+        config = AOPerModuleConfig(
+            {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
+        )
+        # need set `include_input_output_embeddings` to True
+        quant_config = TorchAoConfig(quant_type=config, include_input_output_embeddings=True)
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             device_map=self.device,
```
```diff
@@ -220,6 +222,7 @@ class TorchAoTest(unittest.TestCase):
         )
         # making sure embedding is quantized
         self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.lm_head.weight, AffineQuantizedTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)

         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
```
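Continuing the sketch above, a quick check that both ends of the embedding stack actually carry quantized weights, matching the assertion the test adds (same placeholder checkpoint):

```python
from torchao.dtypes import AffineQuantizedTensor
from transformers import AutoTokenizer

# Both the input embedding and the lm_head weight should now be quantized.
assert isinstance(model.model.embed_tokens.weight, AffineQuantizedTensor)
assert isinstance(model.lm_head.weight, AffineQuantizedTensor)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # same placeholder
inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0], skip_special_tokens=True))
```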