Add Qwen2Moe GGUF loading support (#33264)

* update gguf doc, config and tensor mapping

* add qwen2moe architecture support, GGUFQwen2MoeConverter and q4 unit tests

* apply code style fixes

* reformat files

* assign GGUFQwen2Converter to qwen2_moe
Vladislav Bronzov 2024-09-05 17:42:03 +02:00, committed by GitHub
parent 132e87500e
commit 5d11de4a2f
4 changed files with 76 additions and 5 deletions

docs/source/en/gguf.md

@@ -78,6 +78,7 @@ For now the supported model architectures are the architectures that have been v
 - LLaMa
 - Mistral
 - Qwen2
+- Qwen2Moe
 
 ## Example usage
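For reference, the feature this doc section covers is the `gguf_file` argument to `from_pretrained`. A minimal sketch of loading the Qwen2Moe checkpoint exercised by the tests added in this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
gguf_file = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"

# GGUF tensors are dequantized to torch tensors at load time, so the
# returned model behaves like any other transformers model.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)
```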

src/transformers/integrations/ggml.py

@@ -79,6 +79,21 @@ GGUF_TENSOR_MAPPING = {
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "qwen2moe": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_gate": "mlp.gate_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_q": "self_attn.q_proj",
+        "attn_v": "self_attn.v_proj",
+        "attn_k": "self_attn.k_proj",
+        "attn_output": "self_attn.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.norm",
+    },
 }
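This mapping drives the rename pass that turns llama.cpp tensor names into transformers state-dict keys. A hypothetical sketch of that pass (the helper name is made up; the real loader also handles dequantization):

```python
# "blk.3.ffn_up.weight" should become "model.layers.3.mlp.up_proj.weight".
QWEN2MOE_SUBSET = {
    "blk": "model.layers",
    "ffn_up": "mlp.up_proj",
}

def rename_gguf_tensor(name: str, mapping: dict) -> str:
    # Replace each dot-separated segment that has a mapping entry;
    # numeric layer indices and "weight" pass through unchanged.
    return ".".join(mapping.get(part, part) for part in name.split("."))

assert rename_gguf_tensor("blk.3.ffn_up.weight", QWEN2MOE_SUBSET) == "model.layers.3.mlp.up_proj.weight"
```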
@@ -123,6 +138,18 @@ GGUF_CONFIG_MAPPING = {
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
+    "qwen2moe": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+    },
     "tokenizer": {
         "ggml.bos_token_id": "bos_token_id",
         "ggml.eos_token_id": "eos_token_id",
@@ -244,7 +271,15 @@ class GGUFLlamaConverter(LlamaConverter):
         bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
         eos_token = proto.tokens[proto.eos_token_id] if getattr(proto, "eos_token_id", None) is not None else None
-        tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))
+        tokenizer = Tokenizer(
+            BPE(
+                bpe_vocab,
+                merges,
+                unk_token=unk_token,
+                fuse_unk=True,
+                byte_fallback=True,
+            )
+        )
 
         special_tokens = []
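For illustration, the same `tokenizers` construction with a self-contained toy vocabulary standing in for the `bpe_vocab` and `merges` recovered from the GGUF metadata:

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

vocab = {"<unk>": 0, "h": 1, "i": 2, "hi": 3}  # toy stand-ins
merges = [("h", "i")]

tokenizer = Tokenizer(
    BPE(
        vocab,
        merges,
        unk_token="<unk>",
        fuse_unk=True,       # collapse runs of unknown tokens into one
        byte_fallback=True,  # represent unseen characters as byte tokens
    )
)
print(tokenizer.encode("hi").tokens)  # ['hi']
```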
@@ -358,6 +393,7 @@ class GGUFQwen2Converter(Qwen2Converter):
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
+    "qwen2_moe": GGUFQwen2Converter,
 }
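Qwen2Moe deliberately reuses `GGUFQwen2Converter` instead of adding a dedicated converter, since both architectures ship the same BPE tokenizer format; dispatch is a plain dict lookup keyed by the normalized architecture name. A stand-in sketch of the pattern:

```python
# Stand-in class; the real converter builds a fast tokenizer
# from the parsed GGUF tokenizer fields.
class GGUFQwen2Converter:
    def __init__(self, tokenizer_dict):
        self.tokenizer_dict = tokenizer_dict

GGUF_TO_FAST_CONVERTERS = {
    "qwen2": GGUFQwen2Converter,
    "qwen2_moe": GGUFQwen2Converter,  # same tokenizer format, same converter
}

converter = GGUF_TO_FAST_CONVERTERS["qwen2_moe"]({"tokens": [], "merges": []})
```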

src/transformers/modeling_gguf_pytorch_utils.py

@@ -96,6 +96,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     else:
         updated_architecture = architecture
 
+    if "qwen2moe" in architecture:
+        updated_architecture = "qwen2_moe"
+
     if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"Architecture {architecture} not supported")
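The extra branch is needed because llama.cpp registers this architecture as `qwen2moe`, while the transformers `model_type` is `qwen2_moe`, and everything downstream (config mapping, tensor mapping, converter registry) keys off the transformers name. As a standalone helper, the logic would look roughly like:

```python
def normalize_gguf_architecture(architecture: str) -> str:
    """Hypothetical helper mirroring the hunk above: map llama.cpp's
    architecture string onto transformers' model_type."""
    if "qwen2moe" in architecture:
        return "qwen2_moe"
    return architecture

assert normalize_gguf_architecture("qwen2moe") == "qwen2_moe"
assert normalize_gguf_architecture("llama") == "llama"
```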

tests/quantization/ggml/test_ggml.py

@@ -16,7 +16,12 @@ import tempfile
 import unittest
 
 from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import (
+    require_gguf,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_torch_available
@@ -33,6 +38,7 @@ class GgufIntegrationTests(unittest.TestCase):
     imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
     mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
     qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
+    qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
     tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
@@ -59,6 +65,7 @@
     q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
     q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
+    q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
     q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
     f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
@@ -298,7 +305,10 @@
     def test_mistral_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
+            self.mistral_model_id,
+            gguf_file=self.q4_0_mistral_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
@@ -310,7 +320,10 @@
     def test_qwen2_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
+            self.qwen2_model_id,
+            gguf_file=self.q4_0_qwen2_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
@@ -319,6 +332,21 @@
         EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_qwen2_moe_q4_0(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.qwen2_moe_model_id,
+            gguf_file=self.q4_0_qwen2_moe_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_llama3_q4_0_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -331,7 +359,10 @@
     def test_llama3_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
+            self.llama3_model_id,
+            gguf_file=self.q4_llama3_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
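With the usual transformers test setup, the new integration test should be runnable with something like `RUN_SLOW=1 pytest tests/quantization/ggml/test_ggml.py -k qwen2_moe` on a machine with a CUDA GPU and the `gguf` package installed (the imports above gate on `require_gguf` and `require_torch_gpu`).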