mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add gguf support for bloom (#33473)
* add bloom arch support for gguf * apply format * small refactoring, bug fix in GGUF_TENSOR_MAPPING naming * optimize bloom GGUF_TENSOR_MAPPING * implement reverse reshaping for bloom gguf * add qkv weights test * add q_8 test for bloom
This commit is contained in:
parent
3e039d3827
commit
9d200cfbee
@ -80,6 +80,7 @@ For now the supported model architectures are the architectures that have been v
|
||||
- Qwen2
|
||||
- Qwen2Moe
|
||||
- Phi3
|
||||
- Bloom
|
||||
|
||||
## Example usage
|
||||
|
||||
|
@ -328,9 +328,11 @@ class OpenAIGPTConverter(Converter):
|
||||
|
||||
|
||||
class GPT2Converter(Converter):
|
||||
def converted(self) -> Tokenizer:
|
||||
vocab = self.original_tokenizer.encoder
|
||||
merges = list(self.original_tokenizer.bpe_ranks.keys())
|
||||
def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
|
||||
if not vocab:
|
||||
vocab = self.original_tokenizer.encoder
|
||||
if not merges:
|
||||
merges = list(self.original_tokenizer.bpe_ranks)
|
||||
|
||||
tokenizer = Tokenizer(
|
||||
BPE(
|
||||
@ -343,9 +345,11 @@ class GPT2Converter(Converter):
|
||||
)
|
||||
)
|
||||
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
|
||||
add_prefix_space = False
|
||||
add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
if self.original_tokenizer.add_bos_token:
|
||||
if getattr(self.original_tokenizer, "add_bos_token", False):
|
||||
bos = self.original_tokenizer.bos_token
|
||||
bos_token_id = self.original_tokenizer.bos_token_id
|
||||
tokenizer.post_processor = processors.TemplateProcessing(
|
||||
|
@ -25,7 +25,7 @@ from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
|
||||
from tokenizers.models import BPE
|
||||
|
||||
from .. import AddedToken
|
||||
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
|
||||
from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter
|
||||
from ..utils import logging
|
||||
from ..utils.logging import tqdm
|
||||
|
||||
@ -107,6 +107,19 @@ GGUF_TENSOR_MAPPING = {
|
||||
"output.weight": "lm_head.weight",
|
||||
"output_norm": "model.norm",
|
||||
},
|
||||
"bloom": {
|
||||
"token_embd.weight": "transformer.word_embeddings.weight",
|
||||
"token_embd_norm": "transformer.word_embeddings_layernorm",
|
||||
"blk": "transformer.h",
|
||||
"ffn_up": "mlp.dense_h_to_4h",
|
||||
"ffn_down": "mlp.dense_4h_to_h",
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"attn_norm": "input_layernorm",
|
||||
"attn_qkv": "self_attention.query_key_value",
|
||||
"attn_output": "self_attention.dense",
|
||||
"output.weight": "lm_head.weight",
|
||||
"output_norm": "transformer.ln_f",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -183,6 +196,13 @@ GGUF_CONFIG_MAPPING = {
|
||||
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
|
||||
"vocab_size": "vocab_size",
|
||||
},
|
||||
"bloom": {
|
||||
"block_count": "n_layer",
|
||||
"embedding_length": "hidden_size",
|
||||
"attention.head_count": "n_head",
|
||||
"vocab_size": "vocab_size",
|
||||
"attention.layer_norm_epsilon": "layer_norm_epsilon",
|
||||
},
|
||||
}
|
||||
|
||||
GGUF_TOKENIZER_MAPPING = {
|
||||
@ -492,11 +512,24 @@ class GGUFPhi3Converter(LlamaConverter):
|
||||
return tokenizer
|
||||
|
||||
|
||||
class GGUFBloomConverter(GPT2Converter):
|
||||
def __init__(self, tokenizer_dict):
|
||||
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
|
||||
self.additional_kwargs = {}
|
||||
|
||||
def converted(self) -> Tokenizer:
|
||||
vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
|
||||
merges = self.original_tokenizer.merges
|
||||
tokenizer = super().converted(vocab, merges)
|
||||
return tokenizer
|
||||
|
||||
|
||||
GGUF_TO_FAST_CONVERTERS = {
|
||||
"llama": GGUFLlamaConverter,
|
||||
"qwen2": GGUFQwen2Converter,
|
||||
"qwen2_moe": GGUFQwen2Converter,
|
||||
"phi3": GGUFPhi3Converter,
|
||||
"bloom": GGUFBloomConverter,
|
||||
}
|
||||
|
||||
|
||||
|
@ -169,6 +169,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
elif ".attn_k." in name:
|
||||
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
|
||||
|
||||
if architecture == "bloom" and "attn_qkv" in name:
|
||||
num_heads = parsed_parameters["config"]["n_head"]
|
||||
n_embed = parsed_parameters["config"]["hidden_size"]
|
||||
if "weight" in name:
|
||||
weights = reverse_reshape_weights(weights, num_heads, n_embed)
|
||||
else:
|
||||
weights = reverse_reshape_bias(weights, num_heads, n_embed)
|
||||
|
||||
for tensor_name in tensor_key_mapping:
|
||||
if tensor_name in name:
|
||||
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
|
||||
@ -191,3 +199,29 @@ def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Opti
|
||||
dim = weights.shape[0] // n_head // 2
|
||||
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
|
||||
return w.swapaxes(2, 1).reshape(weights.shape)
|
||||
|
||||
|
||||
def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
|
||||
# Original reshape implementation
|
||||
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
|
||||
q, k, v = np.array_split(weights, 3, axis=0)
|
||||
|
||||
q = q.reshape(n_head, n_embed // n_head, n_embed)
|
||||
k = k.reshape(n_head, n_embed // n_head, n_embed)
|
||||
v = v.reshape(n_head, n_embed // n_head, n_embed)
|
||||
qkv_weights = np.stack([q, k, v], axis=1)
|
||||
|
||||
return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
|
||||
|
||||
|
||||
def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
|
||||
# Original reshape implementation
|
||||
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
|
||||
q_bias, k_bias, v_bias = np.array_split(weights, 3)
|
||||
|
||||
q_bias = q_bias.reshape(n_head, n_embed // n_head)
|
||||
k_bias = k_bias.reshape(n_head, n_embed // n_head)
|
||||
v_bias = v_bias.reshape(n_head, n_embed // n_head)
|
||||
|
||||
qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
|
||||
return qkv_bias
|
||||
|
@ -99,8 +99,8 @@ class BloomTokenizerFast(PreTrainedTokenizerFast):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file,
|
||||
merges_file,
|
||||
vocab_file=vocab_file,
|
||||
merges_file=merges_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
unk_token=unk_token,
|
||||
bos_token=bos_token,
|
||||
|
@ -42,6 +42,8 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
|
||||
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
|
||||
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
|
||||
bloom_model_id = "afrideva/bloom-560m-GGUF"
|
||||
original_bloom_model_id = "bigscience/bloom-560m"
|
||||
|
||||
# standard quants
|
||||
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
|
||||
@ -69,6 +71,8 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
|
||||
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
|
||||
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
|
||||
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
|
||||
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
|
||||
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
|
||||
|
||||
example_text = "Hello"
|
||||
@ -385,6 +389,62 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_bloom_fp16(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.fp16_bloom_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.bloom_model_id,
|
||||
gguf_file=self.fp16_bloom_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**text, max_new_tokens=10)
|
||||
|
||||
EXPECTED_TEXT = "Hello, I just want to say that I am very"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_bloom_q8_0(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.q8_bloom_model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.bloom_model_id,
|
||||
gguf_file=self.q8_bloom_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**text, max_new_tokens=10)
|
||||
|
||||
EXPECTED_TEXT = "Hello, I just want to say that I am very"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_bloom_weights_conversion_fp16(self):
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.bloom_model_id,
|
||||
gguf_file=self.fp16_bloom_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
original_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.original_bloom_model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quantized_state_dict = quantized_model.state_dict()
|
||||
original_state_dict = original_model.state_dict()
|
||||
|
||||
for (quantized_name, quantized_param), (original_name, original_param) in zip(
|
||||
quantized_state_dict.items(), original_state_dict.items()
|
||||
):
|
||||
if (
|
||||
"self_attention.query_key_value" in quantized_name
|
||||
and "self_attention.query_key_value" in original_name
|
||||
):
|
||||
self.assertTrue(quantized_param.shape == original_param.shape)
|
||||
torch.testing.assert_close(quantized_param, original_param)
|
||||
|
||||
def test_tokenization_xnli(self):
|
||||
import tqdm
|
||||
from datasets import load_dataset
|
||||
|
Loading…
Reference in New Issue
Block a user