Improve gguf tensor processing (#34515)

* add tensor processing system to separate logic for models

* format refactoring

* small fix

* make some methods private

* move custom methods to processors

* refactor tensor processing

* format fix
Vladislav Bronzov 2024-11-21 13:40:49 +01:00 committed by GitHub
parent c57eafdaa1
commit ae5cbf804b

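The change replaces the chain of `if architecture == ...` branches inside `load_gguf_checkpoint` with a registry of `TensorProcessor` subclasses: each architecture's post-dequantization fix-ups live in their own `process` method, which returns a `GGUFTensor` named tuple (`weights`, `name`, `metadata`). Below is a minimal, self-contained sketch of that dispatch flow, condensed from the diff that follows; `DemoTensorProcessor` and `"demo-arch"` are hypothetical stand-ins and not part of the commit.

from typing import NamedTuple

import numpy as np


class GGUFTensor(NamedTuple):
    weights: np.ndarray
    name: str        # None tells the loader the tensor was already handled and should be skipped
    metadata: dict   # extra hints for the loader, e.g. {"bid": block_id} for T5


class TensorProcessor:
    # Default pass-through, used for architectures that need no special handling.
    def __init__(self, config=None):
        self.config = config or {}

    def process(self, weights, name, **kwargs):
        return GGUFTensor(weights, name, {})


class DemoTensorProcessor(TensorProcessor):
    # Hypothetical architecture-specific fix-up (not part of the commit):
    # transpose attention output projections after de-quantization.
    def process(self, weights, name, **kwargs):
        if name.endswith("attn_output.weight"):
            weights = weights.T
        return GGUFTensor(weights, name, {})


TENSOR_PROCESSORS = {"demo-arch": DemoTensorProcessor}

# Dispatch happens once per checkpoint; every tensor is then funneled through process().
processor = TENSOR_PROCESSORS.get("demo-arch", TensorProcessor)(config={})
result = processor.process(np.ones((4, 2), dtype=np.float32), "blk.0.attn_output.weight")
print(result.name, result.weights.shape)  # -> blk.0.attn_output.weight (2, 4)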

@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import re
-from typing import Dict, Optional
+from typing import Dict, NamedTuple, Optional
 
 import numpy as np
 from tqdm import tqdm
@@ -55,6 +55,200 @@ GGUF_TO_TRANSFORMERS_MAPPING = {
 GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())
 
+
+class GGUFTensor(NamedTuple):
+    weights: np.ndarray
+    name: str
+    metadata: dict
+
+
+class TensorProcessor:
+    def __init__(self, config=None):
+        self.config = config or {}
+
+    def process(self, weights, name, **kwargs):
+        return GGUFTensor(weights, name, {})
+
+
+class LlamaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if ".attn_k." in name or ".attn_q." in name:
+            num_heads = self.config.get("num_attention_heads")
+            num_kv_heads = self.config.get("num_key_value_heads")
+
+            if None in (num_heads, num_kv_heads):
+                return GGUFTensor(weights, name, {})
+            if ".attn_q." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_heads)
+            elif ".attn_k." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_permute_weights(
+        self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
+    ) -> np.ndarray:
+        # Original permutation implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
+        if num_kv_heads is not None and n_head != num_kv_heads:
+            n_head = num_kv_heads
+
+        dim = weights.shape[0] // n_head // 2
+        w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+        return w.swapaxes(2, 1).reshape(weights.shape)
+
+
+class Qwen2MoeTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "_exp" in name:
+            tensor_key_mapping = kwargs.get("tensor_key_mapping")
+            parsed_parameters = kwargs.get("parsed_parameters")
+            if tensor_key_mapping:
+                self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
+            return GGUFTensor(weights, None, {})
+        if "ffn_gate_inp_shexp" in name:
+            # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
+            # quantized one is (2048)
+            weights = np.expand_dims(weights, axis=0)
+        return GGUFTensor(weights, name, {})
+
+    def _split_moe_expert_tensor(
+        self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
+    ):
+        # Original merge implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
+        exp_name = ""
+        if "ffn_gate_exps" in name:
+            exp_name = "gate_proj"
+        elif "ffn_down_exps" in name:
+            exp_name = "down_proj"
+        elif "ffn_up_exps" in name:
+            exp_name = "up_proj"
+        else:
+            raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
+
+        for tensor_name in tensor_key_mapping:
+            if tensor_name in name:
+                name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
+
+        w_counter = self.config.get("num_experts", 60)
+        for i in range(0, w_counter):
+            temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
+            exp_weight = weights[i]
+            parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
+
+
+class BloomTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "attn_qkv" in name:
+            num_heads = self.config["n_head"]
+            n_embed = self.config["hidden_size"]
+            if "weight" in name:
+                weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
+            else:
+                weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
+        q, k, v = np.array_split(weights, 3, axis=0)
+
+        q = q.reshape(n_head, n_embed // n_head, n_embed)
+        k = k.reshape(n_head, n_embed // n_head, n_embed)
+        v = v.reshape(n_head, n_embed // n_head, n_embed)
+
+        qkv_weights = np.stack([q, k, v], axis=1)
+        return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
+
+    def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
+        q_bias, k_bias, v_bias = np.array_split(weights, 3)
+
+        q_bias = q_bias.reshape(n_head, n_embed // n_head)
+        k_bias = k_bias.reshape(n_head, n_embed // n_head)
+        v_bias = v_bias.reshape(n_head, n_embed // n_head)
+
+        qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
+        return qkv_bias
+
+
+class T5TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        bid = None
+        for chunk in name.split("."):
+            if chunk.isdigit():
+                bid = int(chunk)
+                break
+        return GGUFTensor(weights, name, {"bid": bid})
+
+
+class GPT2TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        # Original transpose implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
+        if (
+            "attn_qkv.weight" in name
+            or "ffn_down.weight" in name
+            or "ffn_up.weight" in name
+            or "attn_output.weight" in name
+        ):
+            weights = weights.T
+
+        # Handle special case for output.weight
+        if name == "output.weight":
+            # output.weight has conflicts with attn_output.weight in name checking
+            # Store the tensor directly and signal to skip further processing
+            name = "lm_head.weight"
+            parsed_parameters = kwargs.get("parsed_parameters", {})
+            parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+            name = None  # Signal to skip further processing
+
+        return GGUFTensor(weights, name, {})
+
+
+class MambaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "ssm_d" in name and "bias" not in name and "weight" not in name:
+            # ssm_d has conflicts with ssm_dt in name checking
+            # we have to explicitly check that name is exactly ssm_d
+            name = name.replace("ssm_d", "mixer.D")
+        if "ssm_conv1d.weight" in name:
+            # for compatibility tensor ssm_conv1d must be (5120, 1, 4) dim,
+            # quantized one is (5120, 4)
+            weights = np.expand_dims(weights, axis=1)
+        if "ssm_a" in name:
+            # Original exponential implementation
+            # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
+            weights = np.log(-weights)
+        return GGUFTensor(weights, name, {})
+
+
+TENSOR_PROCESSORS = {
+    "llama": LlamaTensorProcessor,
+    "qwen2moe": Qwen2MoeTensorProcessor,
+    "bloom": BloomTensorProcessor,
+    "t5": T5TensorProcessor,
+    "t5encoder": T5TensorProcessor,
+    "gpt2": GPT2TensorProcessor,
+    "mamba": MambaTensorProcessor,
+}
+
 
 def read_field(reader, field):
     value = reader.fields[field]
     return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
@@ -177,73 +371,28 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
 
     if return_tensors:
         tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]
 
+        config = parsed_parameters.get("config", {})
+
+        ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
+        processor = ProcessorClass(config=config)
+
         for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
             name = tensor.name
             weights = dequantize(tensor.data, tensor.tensor_type)
 
-            if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
-                num_heads = parsed_parameters["config"]["num_attention_heads"]
-                num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
-                if ".attn_q." in name:
-                    weights = reverse_permute_weights(weights, num_heads, num_heads)
-                elif ".attn_k." in name:
-                    weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
-
-            if architecture == "qwen2moe":
-                if "_exp" in name:
-                    split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
-                    continue
-                if "ffn_gate_inp_shexp" in name:
-                    # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
-                    # quantized one is (2048)
-                    weights = np.expand_dims(weights, axis=0)
-
-            if architecture == "bloom" and "attn_qkv" in name:
-                num_heads = parsed_parameters["config"]["n_head"]
-                n_embed = parsed_parameters["config"]["hidden_size"]
-                if "weight" in name:
-                    weights = reverse_reshape_weights(weights, num_heads, n_embed)
-                else:
-                    weights = reverse_reshape_bias(weights, num_heads, n_embed)
-
-            bid = None
-            if architecture in ("t5", "t5encoder"):
-                for chunk in name.split("."):
-                    if chunk.isdigit():
-                        bid = int(chunk)
-                        break
-
-            if architecture == "gpt2":
-                if (
-                    "attn_qkv.weight" in name
-                    or "ffn_down.weight" in name
-                    or "ffn_up.weight" in name
-                    or "attn_output.weight" in name
-                ):
-                    # Original transpose implementation
-                    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
-                    weights = weights.T
-                if name == "output.weight":
-                    # output.weight has conflicts with attn_output.weight in name checking
-                    # we have to explicitly check that name is exactly output.weight
-                    name = "lm_head.weight"
-                    parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
-                    continue
-
-            if architecture == "mamba":
-                if "ssm_d" in name and "bias" not in name and "weight" not in name:
-                    # ssm_d has conflicts with ssm_dt in name checking
-                    # we have to explicitly check that name is exactly ssm_d
-                    name = name.replace("ssm_d", "mixer.D")
-                if "ssm_conv1d.weight" in name:
-                    # for compatibility tensor ssm_conv1d must be (5120, 1, 4) dim,
-                    # quantized one is (5120, 4)
-                    weights = np.expand_dims(weights, axis=1)
-                if "ssm_a" in name:
-                    # Original exponential implementation
-                    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
-                    weights = np.log(-weights)
+            result = processor.process(
+                weights=weights,
+                name=name,
+                tensor_key_mapping=tensor_key_mapping,
+                parsed_parameters=parsed_parameters,
+            )
+
+            weights = result.weights
+            name = result.name
+            bid = result.metadata.get("bid")
+
+            if name is None:
+                continue
 
             for tensor_name in tensor_key_mapping:
                 if tensor_name.format(bid=bid) in name:
@@ -256,64 +405,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
 
         logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
 
     return parsed_parameters
-
-
-def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
-    # Original permutation implementation
-    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
-    if num_kv_heads is not None and n_head != num_kv_heads:
-        n_head = num_kv_heads
-
-    dim = weights.shape[0] // n_head // 2
-    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
-    return w.swapaxes(2, 1).reshape(weights.shape)
-
-
-def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
-    # Original reshape implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
-    q, k, v = np.array_split(weights, 3, axis=0)
-
-    q = q.reshape(n_head, n_embed // n_head, n_embed)
-    k = k.reshape(n_head, n_embed // n_head, n_embed)
-    v = v.reshape(n_head, n_embed // n_head, n_embed)
-
-    qkv_weights = np.stack([q, k, v], axis=1)
-    return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
-
-
-def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
-    # Original reshape implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
-    q_bias, k_bias, v_bias = np.array_split(weights, 3)
-
-    q_bias = q_bias.reshape(n_head, n_embed // n_head)
-    k_bias = k_bias.reshape(n_head, n_embed // n_head)
-    v_bias = v_bias.reshape(n_head, n_embed // n_head)
-
-    qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
-    return qkv_bias
-
-
-def split_moe_expert_tensor(
-    weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
-):
-    # Original merge implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
-    exp_name = ""
-    if "ffn_gate_exps" in name:
-        exp_name = "gate_proj"
-    elif "ffn_down_exps" in name:
-        exp_name = "down_proj"
-    elif "ffn_up_exps" in name:
-        exp_name = "up_proj"
-    else:
-        raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
-
-    for tensor_name in tensor_key_mapping:
-        if tensor_name in name:
-            name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
-
-    w_counter = parsed_parameters["config"].get("num_experts", 60)
-    for i in range(0, w_counter):
-        temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
-        exp_weight = weights[i]
-        parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))