diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py index dea74b2f7c3..8ef9a7ec96a 100644 --- a/src/transformers/integrations/awq.py +++ b/src/transformers/integrations/awq.py @@ -30,6 +30,13 @@ AWQ_FUSED_MAPPINGS = { "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], "use_alibi": False, }, + "mixtral": { + "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], + "mlp": ["w1", "w3", "w2"], + "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], + "use_alibi": False, + "rope_theta": 1000000.0, + }, "llama": { "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], "mlp": ["gate_proj", "up_proj", "down_proj"], @@ -353,6 +360,8 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na previous_device, modules_to_fuse["max_seq_len"], use_alibi=modules_to_fuse["use_alibi"], + # The default value in autoawq is set to 10000.0 + rope_theta=modules_to_fuse.get("rope_theta", 10000.0), ) fused_attention_layer.is_hf_transformers = True diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2dc5a15a899..d7eec086bf9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3587,7 +3587,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # a `modules_to_not_convert` attribute we need to manually set that attribute into the # passed `quantization_config` elif ( - quantization_config.modules_to_not_convert is None + getattr(quantization_config, "modules_to_not_convert", None) is None and "modules_to_not_convert" in config.quantization_config ): quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"] diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index 6ce7fca8fc6..be10eb8f918 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -254,6 +254,9 @@ class AwqFusedTest(unittest.TestCase): custom_mapping_model_id = "TheBloke/Yi-34B-AWQ" custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589" + mixtral_model_name = "casperhansen/mixtral-instruct-awq" + mixtral_model_revision = "87dd4ec502dde74fb3a624835c776b000d190c3b" + multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq" multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442" @@ -265,6 +268,7 @@ class AwqFusedTest(unittest.TestCase): EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for" EXPECTED_GENERATION_CUSTOM_MODEL = "HelloWorld.java:11)\r\n\tat org" + EXPECTED_GENERATION_MIXTRAL = prompt + " You're on the North Pole.\n\nThe" def tearDown(self): gc.collect() @@ -300,6 +304,24 @@ class AwqFusedTest(unittest.TestCase): with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) + def test_fused_modules_to_not_convert(self): + """ + Test if fused + modules to_not_covnert work as expected + """ + model_id = "hf-internal-testing/Mixtral-tiny-AWQ" + + quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True) + model = AutoModelForCausalLM.from_pretrained( + model_id, + quantization_config=quantization_config, + low_cpu_mem_usage=True, + ).to(torch_device) + + # Check if model has been correctly fused + self._check_fused_modules(model) + # Checks if the modules_to_not_convert (here gate layer) is a Linear + self.assertTrue(isinstance(model.model.layers[0].block_sparse_moe.gate, torch.nn.Linear)) + def test_generation_fused(self): """ Test generation quality for fused models - single batch case @@ -408,3 +430,24 @@ class AwqFusedTest(unittest.TestCase): outputs = model.generate(**inputs, max_new_tokens=12) self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL) + + @require_torch_multi_gpu + def test_generation_mixtral_fused(self): + """ + Text generation test for Mixtral + AWQ + fused + """ + quantization_config = AwqConfig(bits=4, fuse_max_seq_len=1024, do_fuse=True) + model = AutoModelForCausalLM.from_pretrained( + self.mixtral_model_name, + quantization_config=quantization_config, + device_map="auto", + revision=self.mixtral_model_revision, + ) + + tokenizer = AutoTokenizer.from_pretrained(self.mixtral_model_name) + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer([self.prompt, self.prompt], return_tensors="pt", padding=True).to(torch_device) + + outputs = model.generate(**inputs, max_new_tokens=12) + self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)