mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Bug fix gguf qwen2moe (#33940)
* fix qwen2moe tensors mapping, add unit tests * add expert tensor split logic, test refactoring * small params refactoring * add comment to tensor reshaping
This commit is contained in:
parent
56be9f1925
commit
22e102ad98
@ -82,10 +82,15 @@ GGUF_TENSOR_MAPPING = {
|
||||
"qwen2moe": {
|
||||
"token_embd": "model.embed_tokens",
|
||||
"blk": "model.layers",
|
||||
"ffn_up": "mlp.up_proj",
|
||||
"ffn_down": "mlp.down_proj",
|
||||
"ffn_gate": "mlp.gate_proj",
|
||||
"ffn_up_exps": "mlp.experts",
|
||||
"ffn_up_shexp": "mlp.shared_expert.up_proj",
|
||||
"ffn_down_exps": "mlp.experts",
|
||||
"ffn_down_shexp": "mlp.shared_expert.down_proj",
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"ffn_gate_inp.weight": "mlp.gate.weight",
|
||||
"ffn_gate_exps": "mlp.experts",
|
||||
"ffn_gate_shexp": "mlp.shared_expert.gate_proj",
|
||||
"ffn_gate_inp_shexp": "mlp.shared_expert_gate",
|
||||
"attn_norm": "input_layernorm",
|
||||
"attn_q": "self_attn.q_proj",
|
||||
"attn_v": "self_attn.v_proj",
|
||||
@ -200,6 +205,8 @@ GGUF_CONFIG_MAPPING = {
|
||||
"attention.head_count_kv": "num_key_value_heads",
|
||||
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
|
||||
"vocab_size": "vocab_size",
|
||||
"expert_count": "num_experts",
|
||||
"expert_used_count": "num_experts_per_tok",
|
||||
},
|
||||
"falcon": {
|
||||
"context_length": "max_position_embeddings",
|
||||
|
@ -174,6 +174,15 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
elif ".attn_k." in name:
|
||||
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
|
||||
|
||||
if architecture == "qwen2moe":
|
||||
if "_exp" in name:
|
||||
split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
|
||||
continue
|
||||
if "ffn_gate_inp_shexp" in name:
|
||||
# for compatibility tensor shared_expert_gate must be (1, 2048) dim,
|
||||
# quantized one is (2048)
|
||||
weights = np.expand_dims(weights, axis=0)
|
||||
|
||||
if architecture == "bloom" and "attn_qkv" in name:
|
||||
num_heads = parsed_parameters["config"]["n_head"]
|
||||
n_embed = parsed_parameters["config"]["hidden_size"]
|
||||
@ -230,3 +239,27 @@ def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
|
||||
|
||||
qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
|
||||
return qkv_bias
|
||||
|
||||
|
||||
def split_moe_expert_tensor(
|
||||
weights: np.ndarray, parsed_parameters: dict[str, dict], name: str, tensor_key_mapping: dict
|
||||
):
|
||||
# Original merge implementation
|
||||
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
|
||||
exp_name = ""
|
||||
if "ffn_gate_exps" in name:
|
||||
exp_name = "gate_proj"
|
||||
elif "ffn_down_exps" in name:
|
||||
exp_name = "down_proj"
|
||||
elif "ffn_up_exps" in name:
|
||||
exp_name = "up_proj"
|
||||
else:
|
||||
raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
|
||||
for tensor_name in tensor_key_mapping:
|
||||
if tensor_name in name:
|
||||
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
|
||||
w_counter = parsed_parameters["config"].get("num_experts", 60)
|
||||
for i in range(0, w_counter):
|
||||
temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
|
||||
exp_weight = weights[i]
|
||||
parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
|
||||
|
@ -38,7 +38,8 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
|
||||
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
|
||||
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
|
||||
qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
|
||||
qwen2moe_model_id = "gdax/Qwen1.5-MoE-A2.7B_gguf"
|
||||
qwen2moe_original_model_id = "Qwen/Qwen1.5-MoE-A2.7B"
|
||||
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
|
||||
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
|
||||
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
|
||||
@ -72,7 +73,7 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
q4_0_phi3_model_id = "Phi-3-mini-4k-instruct-q4.gguf"
|
||||
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
|
||||
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
|
||||
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
|
||||
q8_qwen2moe_model_id = "Qwen1.5-MoE-A2.7B_Q8_0.gguf"
|
||||
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
|
||||
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
|
||||
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
|
||||
@ -80,6 +81,7 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
|
||||
fp16_falcon7b_model_id = "falcon-7b-fp16.gguf"
|
||||
q2_k_falcon40b_model_id = "tiiuae-falcon-40b-Q2_K.gguf"
|
||||
fp16_qwen2moe_model_id = "Qwen1.5-MoE-A2.7B.gguf"
|
||||
|
||||
example_text = "Hello"
|
||||
|
||||
@ -344,21 +346,39 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_qwen2_moe_q4_0(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
|
||||
def test_qwen2moe_q8(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.qwen2moe_model_id, gguf_file=self.q8_qwen2moe_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.qwen2_moe_model_id,
|
||||
gguf_file=self.q4_0_qwen2_moe_model_id,
|
||||
device_map="auto",
|
||||
self.qwen2moe_model_id,
|
||||
gguf_file=self.q8_qwen2moe_model_id,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
text = tokenizer(self.example_text, return_tensors="pt")
|
||||
out = model.generate(**text, max_new_tokens=10)
|
||||
|
||||
EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
|
||||
EXPECTED_TEXT = "Hello, I am a 20 year old male"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_qwen2moe_weights_conversion_fp16(self):
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.qwen2moe_model_id,
|
||||
gguf_file=self.fp16_qwen2moe_model_id,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
original_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.qwen2moe_original_model_id,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quantized_state_dict = quantized_model.state_dict()
|
||||
original_state_dict = original_model.state_dict()
|
||||
|
||||
for layer_name, original_params in original_state_dict.items():
|
||||
if layer_name in quantized_state_dict:
|
||||
self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
|
||||
torch.testing.assert_close(original_params, quantized_state_dict[layer_name])
|
||||
|
||||
def test_phi3_q4_0(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.phi3_model_id, gguf_file=self.q4_0_phi3_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@ -422,7 +442,7 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**text, max_new_tokens=10)
|
||||
|
||||
EXPECTED_TEXT = "Hello, I just want to say that I am very"
|
||||
EXPECTED_TEXT = "Hello, I just want to say that I am just"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_bloom_weights_conversion_fp16(self):
|
||||
|
Loading…
Reference in New Issue
Block a user