mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add Qwen2Moe GGUF loading support (#33264)
* update gguf doc, config and tensor mapping * add qwen2moe architecture support, GGUFQwen2MoeConverter and q4 unit tests * apply code style fixes * reformat files * assign GGUFQwen2Converter to qwen2_moe
This commit is contained in:
parent
132e87500e
commit
5d11de4a2f
@ -78,6 +78,7 @@ For now the supported model architectures are the architectures that have been v
|
||||
- LLaMa
|
||||
- Mistral
|
||||
- Qwen2
|
||||
- Qwen2Moe
|
||||
|
||||
## Example usage
|
||||
|
||||
|
@ -79,6 +79,21 @@ GGUF_TENSOR_MAPPING = {
|
||||
"output.weight": "lm_head.weight",
|
||||
"output_norm": "model.norm",
|
||||
},
|
||||
"qwen2moe": {
|
||||
"token_embd": "model.embed_tokens",
|
||||
"blk": "model.layers",
|
||||
"ffn_up": "mlp.up_proj",
|
||||
"ffn_down": "mlp.down_proj",
|
||||
"ffn_gate": "mlp.gate_proj",
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"attn_norm": "input_layernorm",
|
||||
"attn_q": "self_attn.q_proj",
|
||||
"attn_v": "self_attn.v_proj",
|
||||
"attn_k": "self_attn.k_proj",
|
||||
"attn_output": "self_attn.o_proj",
|
||||
"output.weight": "lm_head.weight",
|
||||
"output_norm": "model.norm",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -123,6 +138,18 @@ GGUF_CONFIG_MAPPING = {
|
||||
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
|
||||
"vocab_size": "vocab_size",
|
||||
},
|
||||
"qwen2moe": {
|
||||
"context_length": "max_position_embeddings",
|
||||
"block_count": "num_hidden_layers",
|
||||
"feed_forward_length": "intermediate_size",
|
||||
"embedding_length": "hidden_size",
|
||||
"rope.dimension_count": None,
|
||||
"rope.freq_base": "rope_theta",
|
||||
"attention.head_count": "num_attention_heads",
|
||||
"attention.head_count_kv": "num_key_value_heads",
|
||||
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
|
||||
"vocab_size": "vocab_size",
|
||||
},
|
||||
"tokenizer": {
|
||||
"ggml.bos_token_id": "bos_token_id",
|
||||
"ggml.eos_token_id": "eos_token_id",
|
||||
@ -244,7 +271,15 @@ class GGUFLlamaConverter(LlamaConverter):
|
||||
bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
|
||||
eos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "eos_token_id", None) is not None else None
|
||||
|
||||
tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))
|
||||
tokenizer = Tokenizer(
|
||||
BPE(
|
||||
bpe_vocab,
|
||||
merges,
|
||||
unk_token=unk_token,
|
||||
fuse_unk=True,
|
||||
byte_fallback=True,
|
||||
)
|
||||
)
|
||||
|
||||
special_tokens = []
|
||||
|
||||
@ -358,6 +393,7 @@ class GGUFQwen2Converter(Qwen2Converter):
|
||||
GGUF_TO_FAST_CONVERTERS = {
|
||||
"llama": GGUFLlamaConverter,
|
||||
"qwen2": GGUFQwen2Converter,
|
||||
"qwen2_moe": GGUFQwen2Converter,
|
||||
}
|
||||
|
||||
|
||||
|
@ -96,6 +96,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
else:
|
||||
updated_architecture = architecture
|
||||
|
||||
if "qwen2moe" in architecture:
|
||||
updated_architecture = "qwen2_moe"
|
||||
|
||||
if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
|
||||
raise ValueError(f"Architecture {architecture} not supported")
|
||||
|
||||
|
@ -16,7 +16,12 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_gguf,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
@ -33,6 +38,7 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
|
||||
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
|
||||
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
|
||||
qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
|
||||
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
|
||||
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
|
||||
|
||||
@ -59,6 +65,7 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
|
||||
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
|
||||
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
|
||||
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
|
||||
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
|
||||
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
|
||||
|
||||
@ -298,7 +305,10 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
def test_mistral_q4_0(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
|
||||
self.mistral_model_id,
|
||||
gguf_file=self.q4_0_mistral_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
@ -310,7 +320,10 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
def test_qwen2_q4_0(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
|
||||
self.qwen2_model_id,
|
||||
gguf_file=self.q4_0_qwen2_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
@ -319,6 +332,21 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_qwen2_moe_q4_0(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.qwen2_moe_model_id,
|
||||
gguf_file=self.q4_0_qwen2_moe_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**text, max_new_tokens=10)
|
||||
|
||||
EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_llama3_q4_0_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@ -331,7 +359,10 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
def test_llama3_q4_0(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
|
||||
self.llama3_model_id,
|
||||
gguf_file=self.q4_llama3_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
|
Loading…
Reference in New Issue
Block a user