[tests / Quantization] Fix bnb test (#27145)

* fix bnb test

* link to GH issue
Younes Belkada, 2023-10-30 15:43:08 +01:00, committed by GitHub
parent 576994963f
commit 6b466771b0

@@ -124,13 +124,13 @@ class MixedInt8Test(BaseMixedInt8Test):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_get_keys_to_not_convert(self):
+    @unittest.skip("Un-skip once https://github.com/mosaicml/llm-foundry/issues/703 is resolved")
+    def test_get_keys_to_not_convert_trust_remote_code(self):
         r"""
-        Test the `get_keys_to_not_convert` function.
+        Test the `get_keys_to_not_convert` function with `trust_remote_code` models.
         """
         from accelerate import init_empty_weights
 
         from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
         from transformers.integrations.bitsandbytes import get_keys_to_not_convert
 
         model_id = "mosaicml/mpt-7b"
@@ -142,7 +142,17 @@ class MixedInt8Test(BaseMixedInt8Test):
             config, trust_remote_code=True, code_revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
         )
         self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
 
+    # without trust_remote_code
+    def test_get_keys_to_not_convert(self):
+        r"""
+        Test the `get_keys_to_not_convert` function.
+        """
+        from accelerate import init_empty_weights
+
+        from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
+        from transformers.integrations.bitsandbytes import get_keys_to_not_convert
+
+        model_id = "mosaicml/mpt-7b"
+        config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7")
+        with init_empty_weights():
+            model = MptForCausalLM(config)
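
For context, `get_keys_to_not_convert` inspects a model's module tree and returns the names of modules that should stay in full precision when the remaining linear layers are converted to int8 (typically the output embedding / tied weights). Below is a minimal usage sketch, assuming `transformers` (with its bitsandbytes integration) and `accelerate` are installed; the exact names returned depend on the architecture, so the expected value in the comment is illustrative only.

    from accelerate import init_empty_weights

    from transformers import AutoConfig, OPTForCausalLM
    from transformers.integrations.bitsandbytes import get_keys_to_not_convert

    # Build the model on the meta device: no weights are allocated, which is
    # enough for get_keys_to_not_convert to inspect the module structure.
    config = AutoConfig.from_pretrained("facebook/opt-350m")
    with init_empty_weights():
        model = OPTForCausalLM(config)

    # Module names to keep in full precision (for OPT this is expected to be
    # the tied lm_head); these modules are skipped when nn.Linear layers are
    # replaced with bitsandbytes Linear8bitLt.
    print(get_keys_to_not_convert(model))

Instantiating under `init_empty_weights` is the same trick the test uses: it keeps the check fast and avoids downloading checkpoint weights just to read the module names.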