Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Improve gguf tensor processing (#34515)
* add tensor processing system to separate logic for models
* format refactoring
* small fix
* make some methods private
* move custom methods to processors
* refactor tensor processing
* format fix
This commit is contained in:
parent c57eafdaa1
commit ae5cbf804b
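The diff below replaces the architecture-specific `if` chains inside `load_gguf_checkpoint` with small per-architecture processor classes looked up in `TENSOR_PROCESSORS`. As a rough illustration of how that dispatch fits together, here is a minimal sketch using only the names introduced in this diff; the architecture string, tensor name, shape, and config values are invented for the example, and the processor classes are assumed to already be in scope (imported from wherever this file lives in your checkout):

import numpy as np

# Hypothetical inputs: a dequantized GGUF tensor and the parsed model config (illustrative values).
architecture = "llama"
config = {"num_attention_heads": 32, "num_key_value_heads": 8}
weights = np.zeros((4096, 4096), dtype=np.float32)  # placeholder tensor
name = "blk.0.attn_q.weight"                         # made-up GGUF tensor name

# Pick the architecture-specific processor, falling back to the no-op base class.
ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
processor = ProcessorClass(config=config)

# Every processor returns a GGUFTensor(weights, name, metadata) named tuple;
# name is None when the processor has already stored the tensor itself.
result = processor.process(weights=weights, name=name)
print(result.name, result.weights.shape, result.metadata)

Keeping each architecture's quirks behind a common `process()` signature is what lets the conversion loop in the diff shrink to a single call.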
@@ -15,7 +15,7 @@
 # limitations under the License.

 import re
-from typing import Dict, Optional
+from typing import Dict, NamedTuple, Optional

 import numpy as np
 from tqdm import tqdm
@@ -55,6 +55,200 @@ GGUF_TO_TRANSFORMERS_MAPPING = {

 GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())

+
+class GGUFTensor(NamedTuple):
+    weights: np.ndarray
+    name: str
+    metadata: dict
+
+
+class TensorProcessor:
+    def __init__(self, config=None):
+        self.config = config or {}
+
+    def process(self, weights, name, **kwargs):
+        return GGUFTensor(weights, name, {})
+
+
+class LlamaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if ".attn_k." in name or ".attn_q." in name:
+            num_heads = self.config.get("num_attention_heads")
+            num_kv_heads = self.config.get("num_key_value_heads")
+
+            if None in (num_heads, num_kv_heads):
+                return GGUFTensor(weights, name, {})
+            if ".attn_q." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_heads)
+            elif ".attn_k." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_permute_weights(
+        self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
+    ) -> np.ndarray:
+        # Original permutation implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
+        if num_kv_heads is not None and n_head != num_kv_heads:
+            n_head = num_kv_heads
+
+        dim = weights.shape[0] // n_head // 2
+        w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+        return w.swapaxes(2, 1).reshape(weights.shape)
+
+
+class Qwen2MoeTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "_exp" in name:
+            tensor_key_mapping = kwargs.get("tensor_key_mapping")
+            parsed_parameters = kwargs.get("parsed_parameters")
+            if tensor_key_mapping:
+                self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
+            return GGUFTensor(weights, None, {})
+        if "ffn_gate_inp_shexp" in name:
+            # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
+            # quantized one is (2048)
+            weights = np.expand_dims(weights, axis=0)
+        return GGUFTensor(weights, name, {})
+
+    def _split_moe_expert_tensor(
+        self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
+    ):
+        # Original merge implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
+        exp_name = ""
+        if "ffn_gate_exps" in name:
+            exp_name = "gate_proj"
+        elif "ffn_down_exps" in name:
+            exp_name = "down_proj"
+        elif "ffn_up_exps" in name:
+            exp_name = "up_proj"
+        else:
+            raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
+        for tensor_name in tensor_key_mapping:
+            if tensor_name in name:
+                name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
+        w_counter = self.config.get("num_experts", 60)
+        for i in range(0, w_counter):
+            temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
+            exp_weight = weights[i]
+            parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
+
+
+class BloomTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "attn_qkv" in name:
+            num_heads = self.config["n_head"]
+            n_embed = self.config["hidden_size"]
+            if "weight" in name:
+                weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
+            else:
+                weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
+        q, k, v = np.array_split(weights, 3, axis=0)
+
+        q = q.reshape(n_head, n_embed // n_head, n_embed)
+        k = k.reshape(n_head, n_embed // n_head, n_embed)
+        v = v.reshape(n_head, n_embed // n_head, n_embed)
+        qkv_weights = np.stack([q, k, v], axis=1)
+
+        return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
+
+    def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
+        q_bias, k_bias, v_bias = np.array_split(weights, 3)
+
+        q_bias = q_bias.reshape(n_head, n_embed // n_head)
+        k_bias = k_bias.reshape(n_head, n_embed // n_head)
+        v_bias = v_bias.reshape(n_head, n_embed // n_head)
+
+        qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
+        return qkv_bias
+
+
+class T5TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        bid = None
+        for chunk in name.split("."):
+            if chunk.isdigit():
+                bid = int(chunk)
+                break
+        return GGUFTensor(weights, name, {"bid": bid})
+
+
+class GPT2TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        # Original transpose implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
+        if (
+            "attn_qkv.weight" in name
+            or "ffn_down.weight" in name
+            or "ffn_up.weight" in name
+            or "attn_output.weight" in name
+        ):
+            weights = weights.T
+
+        # Handle special case for output.weight
+        if name == "output.weight":
+            # output.weight has conflicts with attn_output.weight in name checking
+            # Store the tensor directly and signal to skip further processing
+            name = "lm_head.weight"
+            parsed_parameters = kwargs.get("parsed_parameters", {})
+            parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+            name = None  # Signal to skip further processing
+        return GGUFTensor(weights, name, {})
+
+
+class MambaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "ssm_d" in name and "bias" not in name and "weight" not in name:
+            # ssm_d has conflicts with ssm_dt in name checking
+            # we have to explicitly check that name is exactly ssm_d
+            name = name.replace("ssm_d", "mixer.D")
+        if "ssm_conv1d.weight" in name:
+            # for compatibility tensor ssm_conv1d must be (5120, 1, 4) dim,
+            # quantized one is (5120, 4)
+            weights = np.expand_dims(weights, axis=1)
+        if "ssm_a" in name:
+            # Original exponential implementation
+            # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
+            weights = np.log(-weights)
+        return GGUFTensor(weights, name, {})
+
+
+TENSOR_PROCESSORS = {
+    "llama": LlamaTensorProcessor,
+    "qwen2moe": Qwen2MoeTensorProcessor,
+    "bloom": BloomTensorProcessor,
+    "t5": T5TensorProcessor,
+    "t5encoder": T5TensorProcessor,
+    "gpt2": GPT2TensorProcessor,
+    "mamba": MambaTensorProcessor,
+}
+
+
 def read_field(reader, field):
     value = reader.fields[field]
     return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
@@ -177,73 +371,28 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):

     if return_tensors:
         tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]
+        config = parsed_parameters.get("config", {})
+
+        ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
+        processor = ProcessorClass(config=config)

         for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
             name = tensor.name

             weights = dequantize(tensor.data, tensor.tensor_type)

-            if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
-                num_heads = parsed_parameters["config"]["num_attention_heads"]
-                num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
-                if ".attn_q." in name:
-                    weights = reverse_permute_weights(weights, num_heads, num_heads)
-                elif ".attn_k." in name:
-                    weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
+            result = processor.process(
+                weights=weights,
+                name=name,
+                tensor_key_mapping=tensor_key_mapping,
+                parsed_parameters=parsed_parameters,
+            )

-            if architecture == "qwen2moe":
-                if "_exp" in name:
-                    split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
-                    continue
-                if "ffn_gate_inp_shexp" in name:
-                    # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
-                    # quantized one is (2048)
-                    weights = np.expand_dims(weights, axis=0)
+            weights = result.weights
+            name = result.name
+            bid = result.metadata.get("bid")

-            if architecture == "bloom" and "attn_qkv" in name:
-                num_heads = parsed_parameters["config"]["n_head"]
-                n_embed = parsed_parameters["config"]["hidden_size"]
-                if "weight" in name:
-                    weights = reverse_reshape_weights(weights, num_heads, n_embed)
-                else:
-                    weights = reverse_reshape_bias(weights, num_heads, n_embed)
-
-            bid = None
-            if architecture in ("t5", "t5encoder"):
-                for chunk in name.split("."):
-                    if chunk.isdigit():
-                        bid = int(chunk)
-                        break
-
-            if architecture == "gpt2":
-                if (
-                    "attn_qkv.weight" in name
-                    or "ffn_down.weight" in name
-                    or "ffn_up.weight" in name
-                    or "attn_output.weight" in name
-                ):
-                    # Original transpose implementation
-                    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
-                    weights = weights.T
-                if name == "output.weight":
-                    # output.weight has conflicts with attn_output.weight in name checking
-                    # we have to explicitly check that name is exactly output.weight
-                    name = "lm_head.weight"
-                    parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
-                    continue
-            if architecture == "mamba":
-                if "ssm_d" in name and "bias" not in name and "weight" not in name:
-                    # ssm_d has conflicts with ssm_dt in name checking
-                    # we have to explicitly check that name is exactly ssm_d
-                    name = name.replace("ssm_d", "mixer.D")
-                if "ssm_conv1d.weight" in name:
-                    # for compatibility tensor ssm_conv1d must be (5120, 1, 4) dim,
-                    # quantized one is (5120, 4)
-                    weights = np.expand_dims(weights, axis=1)
-                if "ssm_a" in name:
-                    # Original exponential implementation
-                    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
-                    weights = np.log(-weights)
+            if name is None:
+                continue

             for tensor_name in tensor_key_mapping:
                 if tensor_name.format(bid=bid) in name:
@@ -256,64 +405,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

     return parsed_parameters
-
-
-def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
-    # Original permutation implementation
-    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
-    if num_kv_heads is not None and n_head != num_kv_heads:
-        n_head = num_kv_heads
-
-    dim = weights.shape[0] // n_head // 2
-    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
-    return w.swapaxes(2, 1).reshape(weights.shape)
-
-
-def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
-    # Original reshape implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
-    q, k, v = np.array_split(weights, 3, axis=0)
-
-    q = q.reshape(n_head, n_embed // n_head, n_embed)
-    k = k.reshape(n_head, n_embed // n_head, n_embed)
-    v = v.reshape(n_head, n_embed // n_head, n_embed)
-    qkv_weights = np.stack([q, k, v], axis=1)
-
-    return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
-
-
-def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
-    # Original reshape implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
-    q_bias, k_bias, v_bias = np.array_split(weights, 3)
-
-    q_bias = q_bias.reshape(n_head, n_embed // n_head)
-    k_bias = k_bias.reshape(n_head, n_embed // n_head)
-    v_bias = v_bias.reshape(n_head, n_embed // n_head)
-
-    qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
-    return qkv_bias
-
-
-def split_moe_expert_tensor(
-    weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
-):
-    # Original merge implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
-    exp_name = ""
-    if "ffn_gate_exps" in name:
-        exp_name = "gate_proj"
-    elif "ffn_down_exps" in name:
-        exp_name = "down_proj"
-    elif "ffn_up_exps" in name:
-        exp_name = "up_proj"
-    else:
-        raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
-    for tensor_name in tensor_key_mapping:
-        if tensor_name in name:
-            name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
-    w_counter = parsed_parameters["config"].get("num_experts", 60)
-    for i in range(0, w_counter):
-        temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
-        exp_weight = weights[i]
-        parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
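For readers who want to convince themselves that `_reverse_permute_weights` really undoes the permutation applied at GGUF-conversion time, here is a small self-contained round-trip check. The forward `permute` is paraphrased from the llama.cpp lines referenced in the comments above, and the head counts and shapes are illustrative only:

import numpy as np

def permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
    # Forward permutation, paraphrased from llama.cpp's convert_hf_to_gguf.py.
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (
        weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
        .swapaxes(1, 2)
        .reshape(weights.shape)
    )

def reverse_permute(weights: np.ndarray, n_head: int, num_kv_heads: int) -> np.ndarray:
    # Same logic as LlamaTensorProcessor._reverse_permute_weights in this diff.
    if num_kv_heads is not None and n_head != num_kv_heads:
        n_head = num_kv_heads
    dim = weights.shape[0] // n_head // 2
    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
    return w.swapaxes(2, 1).reshape(weights.shape)

# Illustrative sizes: 32 query heads, 8 KV heads, head_dim 16, hidden size 128.
rng = np.random.default_rng(0)
w = rng.standard_normal((8 * 16, 128)).astype(np.float32)
assert np.array_equal(reverse_permute(permute(w, 32, 8), 32, 8), w)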