fix get_keys_to_not_convert() to return correct modules for full precision inference (#25105)

* add test for `get_keys_to_not_convert`

* add minimum patch to keep mpt lm_head from 8bit quantization

* add revision to
YQ 2023-08-02 16:21:52 +08:00 committed by GitHub
parent f6f567d0be
commit 2230d149f0
2 changed files with 54 additions and 8 deletions
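For context (not part of the commit): a rough sketch of the behavior the fix targets. With the keys returned by `get_keys_to_not_convert` excluded from conversion, the output embedding should remain a regular `nn.Linear` after 8-bit loading while inner projections become `Linear8bitLt`. The snippet below is illustrative only; it assumes a CUDA GPU and an installed `bitsandbytes`, and uses the `facebook/opt-350m` checkpoint referenced in the new test.

import bitsandbytes as bnb
from transformers import AutoModelForCausalLM

# Load OPT-350m in 8-bit; modules reported by get_keys_to_not_convert
# (here "lm_head" and "model.decoder.embed_tokens") are skipped during
# the bitsandbytes Linear replacement.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", load_in_8bit=True, device_map="auto"
)

# The output head stays a plain Linear (kept in higher precision)...
assert not isinstance(model.lm_head, bnb.nn.Linear8bitLt)
# ...while transformer-internal projections are converted to 8-bit.
assert isinstance(model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)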


@@ -265,17 +265,16 @@ def get_keys_to_not_convert(model):
     tied_keys = sum(tied_params, [])
     has_tied_params = len(tied_keys) > 0
-    # Check if it is a base model
-    is_base_model = not hasattr(model, model.base_model_prefix)
+    # If there is not tied weights, we want to keep the lm_head (output_embedding) in full precision
+    if not has_tied_params:
+        output_emb = model.get_output_embeddings()
+        if output_emb is not None:
+            list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
+            return list_last_module
-    # Ignore this for base models (BertModel, GPT2Model, etc.)
-    if (not has_tied_params) and is_base_model:
-        return []
-    # otherwise they have an attached head
+    # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
     list_modules = list(model.named_parameters())
     list_last_module = [list_modules[-1][0]]
     # add last module together with tied weights
     intersection = set(list_last_module) - set(tied_keys)
     list_untouched = list(set(tied_keys)) + list(intersection)
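A minimal, self-contained sketch of the new branch above (not from the repository; the toy module only mimics the `get_output_embeddings()` API of a real `PreTrainedModel`): when the model has no tied weights, the output-embedding module is located by identity in `named_modules()`, and its name is what ends up excluded from 8-bit conversion.

import torch.nn as nn

class ToyLM(nn.Module):
    # Stand-in for a causal LM whose lm_head is NOT tied to the input embedding.
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(10, 4)   # input embedding
        self.lm_head = nn.Linear(4, 10)  # untied output embedding

    def get_output_embeddings(self):
        return self.lm_head

model = ToyLM()
output_emb = model.get_output_embeddings()
# Same lookup as in the patched function: match the module object by identity.
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
print(list_last_module)  # ['lm_head'] -> kept in full precision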


@@ -124,6 +124,53 @@ class MixedInt8Test(BaseMixedInt8Test):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_get_keys_to_not_convert(self):
+        r"""
+        Test the `get_keys_to_not_convert` function.
+        """
+        from accelerate import init_empty_weights
+
+        from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
+        from transformers.utils.bitsandbytes import get_keys_to_not_convert
+
+        model_id = "mosaicml/mpt-7b"
+        config = AutoConfig.from_pretrained(
+            model_id, trust_remote_code=True, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
+        )
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
+
+        # without trust_remote_code
+        config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7")
+        with init_empty_weights():
+            model = MptForCausalLM(config)
+        # The order of the keys does not matter, so we sort them before comparing, same for the other tests.
+        self.assertEqual(get_keys_to_not_convert(model).sort(), ["lm_head", "transformer.wte"].sort())
+
+        model_id = "Salesforce/blip2-opt-2.7b"
+        config = AutoConfig.from_pretrained(model_id, revision="1ef7f63a8f0a144c13fdca8103eb7b4691c74cec")
+        with init_empty_weights():
+            model = Blip2ForConditionalGeneration(config)
+        self.assertEqual(
+            get_keys_to_not_convert(model).sort(),
+            ["language_model.lm_head", "language_model.model.decoder.embed_tokens"].sort(),
+        )
+
+        model_id = "facebook/opt-350m"
+        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
+        with init_empty_weights():
+            model = OPTForCausalLM(config)
+        self.assertEqual(get_keys_to_not_convert(model).sort(), ["lm_head", "model.decoder.embed_tokens"].sort())
+
+        model_id = "roberta-large"
+        config = AutoConfig.from_pretrained(model_id, revision="716877d372b884cad6d419d828bac6c85b3b18d9")
+        with init_empty_weights():
+            model = AutoModelForMaskedLM.from_config(config)
+        self.assertEqual(
+            get_keys_to_not_convert(model).sort(),
+            ["'roberta.embeddings.word_embeddings', 'lm_head', 'lm_head.decoder"].sort(),
+        )
+
     def test_quantization_config_json_serialization(self):
         r"""
         A simple test to check if the quantization config is correctly serialized and deserialized