mirror of https://github.com/huggingface/transformers.git
fix get_keys_to_not_convert() to return correct modules for full precision inference (#25105)
* add test for `get_keys_to_not_convert`
* add minimum patch to keep mpt lm_head from 8bit quantization
* add revision to …
parent f6f567d0be
commit 2230d149f0
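`get_keys_to_not_convert` is what the 8-bit loading path consults to decide which modules stay in full precision. A minimal sketch of calling it directly, mirroring the new MPT test below (the checkpoint id and expected key come from that test; `init_empty_weights` builds the model on the meta device, so nothing is downloaded):

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.bitsandbytes import get_keys_to_not_convert

# Build only the model skeleton: parameters live on the meta device,
# so this is cheap and fetches no checkpoint weights.
config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# With this fix, the tied word embedding is reported, so the lm_head
# path is kept in full precision under 8-bit loading.
print(get_keys_to_not_convert(model))  # expected: ["transformer.wte"]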
@@ -265,17 +265,16 @@ def get_keys_to_not_convert(model):
     tied_keys = sum(tied_params, [])
     has_tied_params = len(tied_keys) > 0
 
-    # Check if it is a base model
-    is_base_model = not hasattr(model, model.base_model_prefix)
-
-    # Ignore this for base models (BertModel, GPT2Model, etc.)
-    if (not has_tied_params) and is_base_model:
-        return []
+    # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision
+    if not has_tied_params:
+        output_emb = model.get_output_embeddings()
+        if output_emb is not None:
+            list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
+            return list_last_module
 
-    # otherwise they have an attached head
+    # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
     list_modules = list(model.named_parameters())
     list_last_module = [list_modules[-1][0]]
-
     # add last module together with tied weights
     intersection = set(list_last_module) - set(tied_keys)
     list_untouched = list(set(tied_keys)) + list(intersection)
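The core of the added branch is the id()-based lookup that maps the output-embedding module object back to its dotted name in `named_modules()`. A self-contained sketch of that mechanism with a toy module (`TinyLM` is invented for illustration, not a transformers class):

import torch.nn as nn

class TinyLM(nn.Module):
    """Invented toy model: an embedding plus an untied linear head."""
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(16, 8)
        self.lm_head = nn.Linear(8, 16, bias=False)

    def get_output_embeddings(self):
        return self.lm_head

model = TinyLM()
output_emb = model.get_output_embeddings()

# Same lookup as the patch: walk named_modules() and match the module
# object by identity to recover its fully-qualified name.
names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
print(names)  # -> ['lm_head']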
@@ -124,6 +124,53 @@ class MixedInt8Test(BaseMixedInt8Test):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_get_keys_to_not_convert(self):
+        r"""
+        Test the `get_keys_to_not_convert` function.
+        """
+        from accelerate import init_empty_weights
+
+        from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
+        from transformers.utils.bitsandbytes import get_keys_to_not_convert
+
+        model_id = "mosaicml/mpt-7b"
+        config = AutoConfig.from_pretrained(
+            model_id, trust_remote_code=True, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
+        )
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
+        # without trust_remote_code
+        config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7")
+        with init_empty_weights():
+            model = MptForCausalLM(config)
+        # The order of the keys does not matter, so we sort them before comparing, same for the other tests.
+        self.assertEqual(sorted(get_keys_to_not_convert(model)), sorted(["lm_head", "transformer.wte"]))
+
+        model_id = "Salesforce/blip2-opt-2.7b"
+        config = AutoConfig.from_pretrained(model_id, revision="1ef7f63a8f0a144c13fdca8103eb7b4691c74cec")
+        with init_empty_weights():
+            model = Blip2ForConditionalGeneration(config)
+        self.assertEqual(
+            sorted(get_keys_to_not_convert(model)),
+            sorted(["language_model.lm_head", "language_model.model.decoder.embed_tokens"]),
+        )
+
+        model_id = "facebook/opt-350m"
+        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
+        with init_empty_weights():
+            model = OPTForCausalLM(config)
+        self.assertEqual(sorted(get_keys_to_not_convert(model)), sorted(["lm_head", "model.decoder.embed_tokens"]))
+
+        model_id = "roberta-large"
+        config = AutoConfig.from_pretrained(model_id, revision="716877d372b884cad6d419d828bac6c85b3b18d9")
+        with init_empty_weights():
+            model = AutoModelForMaskedLM.from_config(config)
+        self.assertEqual(
+            sorted(get_keys_to_not_convert(model)),
+            sorted(["roberta.embeddings.word_embeddings", "lm_head", "lm_head.decoder"]),
+        )
+
     def test_quantization_config_json_serialization(self):
         r"""
         A simple test to check if the quantization config is correctly serialized and deserialized
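For context on where these keys end up: when loading with 8-bit quantization, the list computed by `get_keys_to_not_convert` is used as the set of modules to leave unconverted, and `BitsAndBytesConfig` also exposes a manual override. A sketch of forcing the head to stay unquantized by hand (my reading of the loading path, so treat the automatic-fallback claim as an assumption; the model id is reused from the tests):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# llm_int8_skip_modules overrides the automatically computed list that
# get_keys_to_not_convert would otherwise supply.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["lm_head"],  # keep the output head unquantized
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    device_map="auto",
    quantization_config=quantization_config,
)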