diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index bbd1879fb15..4666fe3576c 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -124,13 +124,13 @@ class MixedInt8Test(BaseMixedInt8Test):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_get_keys_to_not_convert(self):
+    @unittest.skip("Un-skip once https://github.com/mosaicml/llm-foundry/issues/703 is resolved")
+    def test_get_keys_to_not_convert_trust_remote_code(self):
         r"""
-        Test the `get_keys_to_not_convert` function.
+        Test the `get_keys_to_not_convert` function with `trust_remote_code` models.
         """
         from accelerate import init_empty_weights
 
-        from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
         from transformers.integrations.bitsandbytes import get_keys_to_not_convert
 
         model_id = "mosaicml/mpt-7b"
@@ -142,7 +142,17 @@ class MixedInt8Test(BaseMixedInt8Test):
             config, trust_remote_code=True, code_revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
         )
         self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
-        # without trust_remote_code
+
+    def test_get_keys_to_not_convert(self):
+        r"""
+        Test the `get_keys_to_not_convert` function.
+        """
+        from accelerate import init_empty_weights
+
+        from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
+        from transformers.integrations.bitsandbytes import get_keys_to_not_convert
+
+        model_id = "mosaicml/mpt-7b"
         config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7")
         with init_empty_weights():
             model = MptForCausalLM(config)
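
For context, a minimal standalone sketch of what the now-skipped `trust_remote_code` test asserts: the model id, pinned revision, and expected key come from the diff above, while running it outside the unittest harness (and the exact `AutoConfig.from_pretrained` arguments) is an assumption. The upstream test is skipped until https://github.com/mosaicml/llm-foundry/issues/703 is resolved, so the remote MPT code may not load on current versions.

```python
from accelerate import init_empty_weights

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations.bitsandbytes import get_keys_to_not_convert

model_id = "mosaicml/mpt-7b"
revision = "72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"  # pinned revision from the diff

# Instantiate on the meta device: get_keys_to_not_convert only inspects the
# module structure (tied/output embeddings), so no weights are materialized.
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True, revision=revision)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, code_revision=revision)

# Modules listed here are kept in full precision during the int8 conversion.
assert get_keys_to_not_convert(model) == ["transformer.wte"]
```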