Loading GGUF files support (#30391)
* Adds support for loading GGUF files
  Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
  Co-authored-by: 99991 <99991@users.noreply.github.com>
* add q2_k q3_k q5_k support from @99991
* fix tests
* Update doc
* Style
* Docs
* fix CI
* Update docs/source/en/gguf.md
* Update docs/source/en/gguf.md
* Compute merges
* change logic
* add comment for clarity
* add comment for clarity
* Update src/transformers/models/auto/tokenization_auto.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* change logic
* Update src/transformers/modeling_utils.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* change
* Apply suggestions from code review
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/modeling_gguf_pytorch_utils.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* put back comment
* add comment about mistral
* comments and added tests
* fix unconsistent type
* more
* fix tokenizer
* Update src/transformers/modeling_utils.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* address comments about tests and tokenizer + add added_tokens
* from_gguf -> gguf_file
* replace on docs too

---------

Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
Co-authored-by: 99991 <99991@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent bd9f4d7951
commit a42844955f
@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
# For video model testing
RUN python3 -m pip install --no-cache-dir decord av==9.2.0

# For GGUF tests
RUN python3 -m pip install --no-cache-dir gguf

# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes

@ -137,6 +137,8 @@
title: Troubleshoot
- local: hf_quantizer
title: Contribute new quantization method
- local: gguf
title: Interoperability with GGUF files
title: Developer guides
- sections:
- local: performance

docs/source/en/gguf.md (new file, 96 lines)
@ -0,0 +1,96 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GGUF and interaction with Transformers

The GGUF file format is used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and other
libraries that depend on it, like the very popular [llama.cpp](https://github.com/ggerganov/llama.cpp) or
[whisper.cpp](https://github.com/ggerganov/whisper.cpp).

It is a file format [supported by the Hugging Face Hub](https://huggingface.co/docs/hub/en/gguf) with features
allowing for quick inspection of tensors and metadata within the file.

This file format is designed as a "single-file format", where a single file usually contains the configuration
attributes, the tokenizer vocabulary and other attributes, as well as all tensors to be loaded in the model. These
files come in different formats according to the quantization type of the file. We briefly go over some of them
[here](https://huggingface.co/docs/hub/en/gguf#quantization-types).
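
As a quick illustration, you can inspect the metadata and tensors of a GGUF file with the [`gguf`](https://github.com/ggerganov/llama.cpp/tree/master/gguf-py) Python package. This is a minimal sketch; the file name is only an example:

```py
from gguf import GGUFReader

# Any local GGUF file works here; this path is only an example.
reader = GGUFReader("tinyllama-1.1b-chat-v1.0.Q6_K.gguf")

# Metadata keys, e.g. "general.architecture" or "llama.context_length"
print(list(reader.fields.keys())[:10])

# Tensor names, shapes and quantization types
for tensor in reader.tensors[:5]:
    print(tensor.name, tensor.shape, tensor.tensor_type)
```
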
## Support within Transformers

We have added the ability to load `gguf` files within `transformers`, in order to offer further training/fine-tuning
capabilities to GGUF models before converting those models back to `gguf` for use within the `ggml` ecosystem. When
loading a model, we first dequantize it to fp32 before loading the weights to be used in PyTorch.

> [!NOTE]
> The support is still very exploratory and we welcome contributions in order to solidify it across quantization types
> and model architectures.

For now, here are the supported model architectures and quantization types:

### Supported quantization types

The initially supported quantization types were chosen according to the popular quantized files that have been shared
on the Hub.

- F32
- Q2_K
- Q3_K
- Q4_0
- Q4_K
- Q5_K
- Q6_K
- Q8_0

We take inspiration from the excellent [99991/pygguf](https://github.com/99991/pygguf) Python parser to dequantize the
weights.

### Supported model architectures

For now, the supported model architectures are those that have been very popular on the Hub, namely:

- LLaMa
- Mistral

## Example usage

To load `gguf` files in `transformers`, specify the `gguf_file` argument in the `from_pretrained`
methods of both tokenizers and models. Here is how to load a tokenizer and a model, both from the exact same file:

```py
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
filename = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"

tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename)
```

Now you have access to the full, unquantized version of the model in the PyTorch ecosystem, where you can combine it
with a plethora of other tools.
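
For instance, here is a minimal sketch of running generation on the dequantized model through the usual `transformers` API; the prompt and generation settings are only illustrative:

```py
import torch

# Continues from the snippet above; the prompt is only an illustration.
inputs = tokenizer("The capital of France is", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
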
In order to convert back to a `gguf` file, we recommend using the
[`convert-hf-to-gguf.py` script](https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py) from llama.cpp.

Here's how you would complete the script above to save the model and export it back to `gguf`:

```py
tokenizer.save_pretrained('directory')
model.save_pretrained('directory')

!python ${path_to_llama_cpp}/convert-hf-to-gguf.py ${directory}
```
|
@ -27,6 +27,7 @@ from packaging import version
|
||||
|
||||
from . import __version__
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
PushToHubMixin,
|
||||
@ -658,6 +659,8 @@ class PretrainedConfig(PushToHubMixin):
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
gguf_file = kwargs.get("gguf_file", None)
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
|
||||
@ -676,10 +679,10 @@ class PretrainedConfig(PushToHubMixin):
|
||||
resolved_config_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
elif is_remote_url(pretrained_model_name_or_path):
|
||||
configuration_file = pretrained_model_name_or_path
|
||||
configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file
|
||||
resolved_config_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
|
||||
|
||||
try:
|
||||
# Load from local folder or from cache or download from model Hub and cache
|
||||
@ -712,8 +715,12 @@ class PretrainedConfig(PushToHubMixin):
|
||||
)
|
||||
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
if gguf_file:
|
||||
config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
|
||||
else:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
|
||||
config_dict["_commit_hash"] = commit_hash
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(
|
||||
|
@ -44,6 +44,14 @@ _import_structure = {
|
||||
"unset_hf_deepspeed_config",
|
||||
],
|
||||
"eetq": ["replace_with_eetq_linear"],
|
||||
"ggml": [
|
||||
"GGUF_CONFIG_MAPPING",
|
||||
"GGUF_TENSOR_MAPPING",
|
||||
"GGUF_TOKENIZER_MAPPING",
|
||||
"_gguf_parse_value",
|
||||
"load_dequant_gguf_tensor",
|
||||
"load_gguf",
|
||||
],
|
||||
"hqq": ["prepare_for_hqq_linear"],
|
||||
"integration_utils": [
|
||||
"INTEGRATION_TO_CALLBACK",
|
||||
@ -116,6 +124,14 @@ if TYPE_CHECKING:
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
from .eetq import replace_with_eetq_linear
|
||||
from .ggml import (
|
||||
GGUF_CONFIG_MAPPING,
|
||||
GGUF_TENSOR_MAPPING,
|
||||
GGUF_TOKENIZER_MAPPING,
|
||||
_gguf_parse_value,
|
||||
load_dequant_gguf_tensor,
|
||||
load_gguf,
|
||||
)
|
||||
from .hqq import prepare_for_hqq_linear
|
||||
from .integration_utils import (
|
||||
INTEGRATION_TO_CALLBACK,
|
||||
|
src/transformers/integrations/ggml.py (new file, 584 lines)
@ -0,0 +1,584 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
|
||||
# https://github.com/99991/pygguf
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Integration with GGML / The file is copied and adapted from https://github.com/99991/pygguf
|
||||
with extra methods being exposed
|
||||
"""
|
||||
from array import array
|
||||
|
||||
import numpy as np
|
||||
from tokenizers import Tokenizer, decoders
|
||||
from tokenizers.models import BPE
|
||||
|
||||
from .. import AddedToken
|
||||
from ..convert_slow_tokenizer import LlamaConverter
|
||||
from ..utils import logging
|
||||
from ..utils.logging import tqdm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
|
||||
GGML_TYPES = {
|
||||
"F32": 0,
|
||||
"Q4_0": 2,
|
||||
"Q8_0": 8,
|
||||
"Q2_K": 10,
|
||||
"Q3_K": 11,
|
||||
"Q4_K": 12,
|
||||
"Q5_K": 13,
|
||||
"Q6_K": 14,
|
||||
}
|
||||
|
||||
# The Blocksizes are reported in bytes
|
||||
# Check out: https://github.com/ggerganov/llama.cpp/blob/8a56075b07a8b571bf95a912ffdce4c928c2b414/gguf-py/gguf/constants.py#L801
|
||||
GGML_BLOCK_SIZES = {
|
||||
"Q8_0": 2 + 32, # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales
|
||||
"Q4_K": 144,
|
||||
# Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales
|
||||
"Q4_0": 2 + 16,
|
||||
"Q6_K": 210,
|
||||
# See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9
|
||||
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
|
||||
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
|
||||
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
|
||||
}
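# For example (illustrative arithmetic): a Q8_0 block stores 32 int8 weights plus one fp16 scale,
# i.e. 32 + 2 = 34 bytes per block, so a quantized buffer of N bytes contains N // 34 blocks of 32 weights each.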
|
||||
|
||||
# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
|
||||
DATA_TYPES = {
|
||||
"uint32": 4,
|
||||
"int32": 5,
|
||||
"float32": 6,
|
||||
"bool": 7,
|
||||
"string": 8,
|
||||
"array": 9,
|
||||
"uint64": 10,
|
||||
}
|
||||
|
||||
GGUF_TENSOR_MAPPING = {
|
||||
"llama": {
|
||||
"token_embd": "model.embed_tokens",
|
||||
"blk": "model.layers",
|
||||
"ffn_up": "mlp.up_proj",
|
||||
"ffn_down": "mlp.down_proj",
|
||||
"ffn_gate": "mlp.gate_proj",
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"attn_norm": "input_layernorm",
|
||||
"attn_q": "self_attn.q_proj",
|
||||
"attn_v": "self_attn.v_proj",
|
||||
"attn_k": "self_attn.k_proj",
|
||||
"attn_output": "self_attn.o_proj",
|
||||
"output.weight": "lm_head.weight",
|
||||
"output_norm": "model.norm",
|
||||
},
|
||||
"mistral": {
|
||||
"token_embd": "model.embed_tokens",
|
||||
"blk": "model.layers",
|
||||
"ffn_up": "mlp.up_proj",
|
||||
"ffn_down": "mlp.down_proj",
|
||||
"ffn_gate": "mlp.gate_proj",
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"attn_norm": "input_layernorm",
|
||||
"attn_q": "self_attn.q_proj",
|
||||
"attn_v": "self_attn.v_proj",
|
||||
"attn_k": "self_attn.k_proj",
|
||||
"attn_output": "self_attn.o_proj",
|
||||
"output.weight": "lm_head.weight",
|
||||
"output_norm": "model.norm",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
GGUF_CONFIG_MAPPING = {
|
||||
"general": {
|
||||
"architecture": "model_type",
|
||||
"name": "_model_name_or_path",
|
||||
},
|
||||
"llama": {
|
||||
"context_length": "max_position_embeddings",
|
||||
"block_count": "num_hidden_layers",
|
||||
"feed_forward_length": "intermediate_size",
|
||||
"embedding_length": "hidden_size",
|
||||
"rope.dimension_count": None,
|
||||
"rope.freq_base": "rope_theta",
|
||||
"attention.head_count": "num_attention_heads",
|
||||
"attention.head_count_kv": "num_key_value_heads",
|
||||
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
|
||||
"vocab_size": "vocab_size",
|
||||
},
|
||||
"mistral": {
|
||||
"context_length": "max_position_embeddings",
|
||||
"block_count": "num_hidden_layers",
|
||||
"feed_forward_length": "intermediate_size",
|
||||
"embedding_length": "hidden_size",
|
||||
"rope.dimension_count": None,
|
||||
"rope.freq_base": "rope_theta",
|
||||
"attention.head_count": "num_attention_heads",
|
||||
"attention.head_count_kv": "num_key_value_heads",
|
||||
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
|
||||
"vocab_size": "vocab_size",
|
||||
},
|
||||
"tokenizer": {
|
||||
"ggml.model": "model_type",
|
||||
"ggml.bos_token_id": "bos_token_id",
|
||||
"ggml.eos_token_id": "eos_token_id",
|
||||
"ggml.unknown_token_id": "unk_token_id",
|
||||
"ggml.padding_token_id": "pad_token_id",
|
||||
},
|
||||
}
|
||||
|
||||
GGUF_TOKENIZER_MAPPING = {
|
||||
"tokenizer": {
|
||||
"ggml.model": "tokenizer_type",
|
||||
"ggml.tokens": "tokens",
|
||||
"ggml.scores": "scores",
|
||||
"ggml.token_type": "token_type",
|
||||
"ggml.merges": "merges",
|
||||
"ggml.bos_token_id": "bos_token_id",
|
||||
"ggml.eos_token_id": "eos_token_id",
|
||||
"ggml.unknown_token_id": "unk_token_id",
|
||||
"ggml.padding_token_id": "pad_token_id",
|
||||
"ggml.add_space_prefix": "add_prefix_space",
|
||||
},
|
||||
"tokenizer_config": {
|
||||
"chat_template": "chat_template",
|
||||
"ggml.model": "model_type",
|
||||
"ggml.bos_token_id": "bos_token_id",
|
||||
"ggml.eos_token_id": "eos_token_id",
|
||||
"ggml.unknown_token_id": "unk_token_id",
|
||||
"ggml.padding_token_id": "pad_token_id",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _gguf_parse_value(_value, data_type):
|
||||
if not isinstance(data_type, list):
|
||||
data_type = [data_type]
|
||||
if len(data_type) == 1:
|
||||
data_type = data_type[0]
|
||||
array_data_type = None
|
||||
else:
|
||||
if data_type[0] != 9:
|
||||
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
|
||||
data_type, array_data_type = data_type
|
||||
|
||||
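# GGUF scalar type ids (see DATA_TYPES above and the GGUF spec): 0-5, 10 and 11 are integer types,
# 6 and 12 are float types, 7 is bool, 8 is string and 9 is an array of a nested type.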
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
|
||||
_value = int(_value[0])
|
||||
elif data_type in [6, 12]:
|
||||
_value = float(_value[0])
|
||||
elif data_type in [7]:
|
||||
_value = bool(_value[0])
|
||||
elif data_type in [8]:
|
||||
_value = array("B", list(_value)).tobytes().decode()
|
||||
elif data_type in [9]:
|
||||
_value = _gguf_parse_value(_value, array_data_type)
|
||||
return _value
|
||||
|
||||
|
||||
def dequantize_q4_k(data):
|
||||
# C implementation
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929
|
||||
# C struct definition
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
|
||||
block_size = GGML_BLOCK_SIZES["Q4_K"]
|
||||
num_blocks = len(data) // block_size
|
||||
|
||||
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
|
||||
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
|
||||
|
||||
# Casting to float32 because float16 is very slow on CPU
|
||||
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
|
||||
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
|
||||
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
|
||||
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
|
||||
|
||||
# Dequantize scales and offsets (6 bits and 4 + 2 bits)
|
||||
factors = scale_factors * np.concatenate(
|
||||
[qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1
|
||||
)
|
||||
offsets = scale_offsets * np.concatenate(
|
||||
[qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1
|
||||
)
|
||||
|
||||
# Interleave low and high quantized bits
|
||||
qs2 = np.stack([qs2 & 0xF, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
|
||||
# Dequantize final weights using scales and offsets
|
||||
return factors * qs2 - offsets
|
||||
|
||||
|
||||
def dequantize_q4_0(data):
|
||||
# C implementation
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086
|
||||
# C struct definition
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11
|
||||
block_size = GGML_BLOCK_SIZES["Q4_0"]
|
||||
num_blocks = len(data) // block_size
|
||||
|
||||
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
|
||||
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
|
||||
|
||||
# The scales are stored on the first 2 bytes and the rest corresponds to the quants
|
||||
scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
|
||||
# scales = np.nan_to_num(scales)
|
||||
# the rest of the bytes corresponds to the quants - we discard the first two bytes
|
||||
quants = data_u8[:, 2:]
|
||||
|
||||
ql = (quants[:, :] & 0xF).astype(np.int8) - 8
|
||||
qr = (quants[:, :] >> 4).astype(np.int8) - 8
|
||||
|
||||
# Use hstack
|
||||
quants = np.hstack([ql, qr])
|
||||
|
||||
return (scales * quants).astype(np.float32)
|
||||
|
||||
|
||||
def dequantize_q6_k(data):
|
||||
# C implementation
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
|
||||
# C struct definition
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152
|
||||
block_size = GGML_BLOCK_SIZES["Q6_K"]
|
||||
num_blocks = len(data) // block_size
|
||||
|
||||
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
|
||||
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
|
||||
data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
|
||||
|
||||
scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
|
||||
|
||||
# TODO use uint8 and cast later?
|
||||
ql = data_u8[:, :128].astype(np.int16)
|
||||
qh = data_u8[:, 128:192].astype(np.int16)
|
||||
sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
|
||||
|
||||
# Unpack bits, subtraction requires signed data type
|
||||
q1 = (ql[:, :32] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
|
||||
q2 = (ql[:, 32:64] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
|
||||
q3 = (ql[:, :32] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
|
||||
q4 = (ql[:, 32:64] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
|
||||
q5 = (ql[:, 64:96] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
|
||||
q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
|
||||
q7 = (ql[:, 64:96] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
|
||||
q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
|
||||
|
||||
# Dequantize
|
||||
return scales * np.concatenate(
|
||||
[
|
||||
sc[:, 0] * q1[:, :16],
|
||||
sc[:, 1] * q1[:, 16:],
|
||||
sc[:, 2] * q2[:, :16],
|
||||
sc[:, 3] * q2[:, 16:],
|
||||
sc[:, 4] * q3[:, :16],
|
||||
sc[:, 5] * q3[:, 16:],
|
||||
sc[:, 6] * q4[:, :16],
|
||||
sc[:, 7] * q4[:, 16:],
|
||||
sc[:, 8] * q5[:, :16],
|
||||
sc[:, 9] * q5[:, 16:],
|
||||
sc[:, 10] * q6[:, :16],
|
||||
sc[:, 11] * q6[:, 16:],
|
||||
sc[:, 12] * q7[:, :16],
|
||||
sc[:, 13] * q7[:, 16:],
|
||||
sc[:, 14] * q8[:, :16],
|
||||
sc[:, 15] * q8[:, 16:],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
|
||||
def dequantize_q8_0(data):
|
||||
# C struct definition
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
|
||||
block_size = GGML_BLOCK_SIZES["Q8_0"]
|
||||
num_blocks = len(data) // block_size
|
||||
|
||||
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
|
||||
qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
|
||||
|
||||
return scales * qs
|
||||
|
||||
|
||||
def dequantize_q2_k(data):
|
||||
# C implementation
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547
|
||||
# C struct definition
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74
|
||||
num_blocks = len(data) // GGML_BLOCK_SIZES["Q2_K"]
|
||||
|
||||
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"] // 2)
|
||||
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"])
|
||||
|
||||
dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
|
||||
d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
|
||||
scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
|
||||
qs = data_u8[:, 16:80].reshape(num_blocks, 64)
|
||||
|
||||
tmp = np.stack(
|
||||
[
|
||||
qs[:, 00:16] >> 0,
|
||||
qs[:, 16:32] >> 0,
|
||||
qs[:, 00:16] >> 2,
|
||||
qs[:, 16:32] >> 2,
|
||||
qs[:, 00:16] >> 4,
|
||||
qs[:, 16:32] >> 4,
|
||||
qs[:, 00:16] >> 6,
|
||||
qs[:, 16:32] >> 6,
|
||||
qs[:, 32:48] >> 0,
|
||||
qs[:, 48:64] >> 0,
|
||||
qs[:, 32:48] >> 2,
|
||||
qs[:, 48:64] >> 2,
|
||||
qs[:, 32:48] >> 4,
|
||||
qs[:, 48:64] >> 4,
|
||||
qs[:, 32:48] >> 6,
|
||||
qs[:, 48:64] >> 6,
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
|
||||
|
||||
|
||||
def dequantize_q3_k(data):
|
||||
# C implementation
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42
|
||||
# C struct definition
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95
|
||||
num_blocks = len(data) // GGML_BLOCK_SIZES["Q3_K"]
|
||||
|
||||
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"] // 2)
|
||||
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"])
|
||||
|
||||
d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
|
||||
bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
|
||||
bits = 4 ^ (bits << 2)
|
||||
qs = data_u8[:, 32 : 32 + 64].astype(np.int16)
|
||||
a, b, c = data_u8[:, 96 : 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
|
||||
scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
|
||||
scales[:, 0] = (a & 15) | ((c & 3) << 4)
|
||||
scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
|
||||
scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
|
||||
scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
|
||||
scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
|
||||
|
||||
return (
|
||||
d
|
||||
* (scales - 32)
|
||||
* np.stack(
|
||||
[
|
||||
(((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
|
||||
(((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
|
||||
(((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
|
||||
(((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
|
||||
(((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
|
||||
(((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
|
||||
(((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
|
||||
(((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
|
||||
(((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
|
||||
(((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
|
||||
(((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
|
||||
(((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
|
||||
(((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
|
||||
(((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
|
||||
(((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
|
||||
(((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def dequantize_q5_k(data):
|
||||
# C implementation
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129
|
||||
# C struct definition
|
||||
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138
|
||||
num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_K"]
|
||||
|
||||
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"] // 2)
|
||||
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"])
|
||||
|
||||
d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
|
||||
dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
|
||||
scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
|
||||
qh = data_u8[:, 16 : 16 + 32].reshape(num_blocks, 32, 1)
|
||||
qs = data_u8[:, 48 : 48 + 128].reshape(num_blocks, 4, 32)
|
||||
|
||||
bits = np.unpackbits(qh, axis=-1, bitorder="little")
|
||||
|
||||
qs_hi_4 = qs >> 4
|
||||
qs_lo_4 = qs & 15
|
||||
|
||||
scales_lo_6 = scales[:, :8] & 63
|
||||
scales_hi_6 = scales[:, :8] >> 6
|
||||
scales_lo_4 = scales[:, 8:] & 15
|
||||
scales_hi_4 = scales[:, 8:] >> 4
|
||||
|
||||
m1 = dmin * scales_lo_6[:, 4]
|
||||
m2 = dmin * scales_lo_6[:, 5]
|
||||
m3 = dmin * scales_lo_6[:, 6]
|
||||
m4 = dmin * scales_lo_6[:, 7]
|
||||
m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
|
||||
m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
|
||||
m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
|
||||
m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
|
||||
|
||||
d1 = d * scales_lo_6[:, 0]
|
||||
d2 = d * scales_lo_6[:, 1]
|
||||
d3 = d * scales_lo_6[:, 2]
|
||||
d4 = d * scales_lo_6[:, 3]
|
||||
d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
|
||||
d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
|
||||
d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
|
||||
d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
|
||||
|
||||
return np.concatenate(
|
||||
[
|
||||
d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
|
||||
d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
|
||||
d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
|
||||
d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
|
||||
d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
|
||||
d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
|
||||
d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
|
||||
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
|
||||
def load_dequant_gguf_tensor(shape, ggml_type, data):
|
||||
if ggml_type == GGML_TYPES["F32"]:
|
||||
values = data
|
||||
elif ggml_type == GGML_TYPES["Q8_0"]:
|
||||
values = dequantize_q8_0(data)
|
||||
elif ggml_type == GGML_TYPES["Q4_0"]:
|
||||
values = dequantize_q4_0(data)
|
||||
elif ggml_type == GGML_TYPES["Q4_K"]:
|
||||
values = dequantize_q4_k(data)
|
||||
elif ggml_type == GGML_TYPES["Q6_K"]:
|
||||
values = dequantize_q6_k(data)
|
||||
elif ggml_type == GGML_TYPES["Q2_K"]:
|
||||
values = dequantize_q2_k(data)
|
||||
elif ggml_type == GGML_TYPES["Q3_K"]:
|
||||
values = dequantize_q3_k(data)
|
||||
elif ggml_type == GGML_TYPES["Q5_K"]:
|
||||
values = dequantize_q5_k(data)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose"
|
||||
)
|
||||
|
||||
return values.reshape(shape[::-1])
|
||||
|
||||
|
||||
class GGUFTokenizerSkeleton:
|
||||
def __init__(self, dict_):
|
||||
for k, v in dict_.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
if not hasattr(self, "tokens") or not hasattr(self, "scores"):
|
||||
raise ValueError("tokens and scores need to be passed for a LLaMa tokenizer to be instantiated.")
|
||||
else:
|
||||
tokens = self.tokens
|
||||
scores = self.scores
|
||||
vocab = {t: scores[i] for i, t in enumerate(tokens)}
|
||||
|
||||
if not hasattr(self, "merges"):
|
||||
logger.warning("Merges were not in checkpoint, building merges on the fly.")
|
||||
merges = []
|
||||
for merge, piece_score in tqdm(vocab.items()):
|
||||
local = []
|
||||
for index in range(1, len(merge)):
|
||||
piece_l, piece_r = merge[:index], merge[index:]
|
||||
if piece_l in tokens and piece_r in tokens:
|
||||
local.append((piece_l, piece_r, piece_score))
|
||||
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]), reverse=True)
|
||||
merges.extend(local)
|
||||
merges = sorted(merges, key=lambda val: val[2], reverse=True)
|
||||
merges = [(val[0], val[1]) for val in merges]
|
||||
self.merges = merges
|
||||
else:
|
||||
self.merges = [tuple(merge.split(" ")) for merge in self.merges]
|
||||
|
||||
if not hasattr(self, "added_tokens"):
|
||||
self.added_tokens = []
|
||||
|
||||
|
||||
class GGUFLlamaConverter(LlamaConverter):
|
||||
def __init__(self, tokenizer_dict):
|
||||
self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
|
||||
self.original_tokenizer = self.proto
|
||||
|
||||
def vocab(self, proto):
|
||||
return list(zip(proto.tokens, proto.scores))
|
||||
|
||||
def merges(self, proto):
|
||||
return proto.merges
|
||||
|
||||
def tokenizer(self, proto):
|
||||
vocab_scores = self.vocab(self.proto)
|
||||
merges = self.merges(self.proto)
|
||||
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
|
||||
tokenizer = Tokenizer(
|
||||
BPE(bpe_vocab, merges, unk_token=proto.tokens[proto.unk_token_id], fuse_unk=True, byte_fallback=True)
|
||||
)
|
||||
tokenizer.add_special_tokens(
|
||||
[
|
||||
AddedToken("<unk>", normalized=False, special=True),
|
||||
AddedToken("<s>", normalized=False, special=True),
|
||||
AddedToken("</s>", normalized=False, special=True),
|
||||
]
|
||||
)
|
||||
|
||||
if len(self.proto.added_tokens) != 0:
|
||||
tokenizer.add_special_tokens(
|
||||
[AddedToken(added_token, normalized=False, special=False) for added_token in self.proto.added_tokens]
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
|
||||
def decoder(self, replacement, add_prefix_space):
|
||||
sequence = [
|
||||
decoders.ByteFallback(),
|
||||
decoders.Fuse(),
|
||||
decoders.Replace("▁", " "),
|
||||
]
|
||||
if add_prefix_space:
|
||||
sequence += [decoders.Strip(content=" ", left=1)]
|
||||
return decoders.Sequence(sequence)
|
||||
|
||||
|
||||
GGUF_TO_FAST_CONVERTERS = {
|
||||
"llama": GGUFLlamaConverter,
|
||||
}
|
||||
|
||||
|
||||
def convert_gguf_tokenizer(tokenizer_dict) -> Tokenizer:
|
||||
"""
|
||||
Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
|
||||
|
||||
Args:
|
||||
transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
|
||||
Instance of a slow tokenizer to convert in the backend tokenizer for
|
||||
[`~tokenization_utils_base.PreTrainedTokenizerFast`].
|
||||
|
||||
Return:
|
||||
A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
|
||||
[`~tokenization_utils_base.PreTrainedTokenizerFast`]
|
||||
"""
|
||||
tokenizer_class_name = tokenizer_dict["tokenizer_type"]
|
||||
converter_class = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name]
|
||||
return converter_class(tokenizer_dict).converted()
|
src/transformers/modeling_gguf_pytorch_utils.py (new file, 165 lines)
@ -0,0 +1,165 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
|
||||
# https://github.com/99991/pygguf
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from .integrations import (
|
||||
GGUF_CONFIG_MAPPING,
|
||||
GGUF_TENSOR_MAPPING,
|
||||
GGUF_TOKENIZER_MAPPING,
|
||||
_gguf_parse_value,
|
||||
load_dequant_gguf_tensor,
|
||||
)
|
||||
from .utils import is_torch_available
|
||||
from .utils.logging import get_logger
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
GGUF_TO_TRANSFORMERS_MAPPING = {
|
||||
"ignore": {
|
||||
"GGUF": {
|
||||
"version": "version",
|
||||
"tensor_count": "tensor_count",
|
||||
"kv_count": "kv_count",
|
||||
},
|
||||
"general": {"file_type": "file_type", "quantization_version": "quantization_version"},
|
||||
},
|
||||
"config": GGUF_CONFIG_MAPPING,
|
||||
"tensors": GGUF_TENSOR_MAPPING,
|
||||
"tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]},
|
||||
"tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]},
|
||||
}
|
||||
|
||||
GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())
|
||||
|
||||
|
||||
def read_field(reader, field):
|
||||
value = reader.fields[field]
|
||||
return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
|
||||
|
||||
|
||||
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
"""
|
||||
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed
|
||||
tokenizer and config attributes.
|
||||
|
||||
Args:
|
||||
gguf_checkpoint_path (`str`):
|
||||
The path the to GGUF file to load
|
||||
return_tensors (`bool`, defaults to `True`):
|
||||
Whether to read the tensors from the file and return them. Not doing so is faster
|
||||
and only loads the metadata in memory.
|
||||
"""
|
||||
try:
|
||||
from gguf import GGUFReader
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF to be installed. Please see "
|
||||
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
|
||||
)
|
||||
raise
|
||||
|
||||
reader = GGUFReader(gguf_checkpoint_path)
|
||||
fields = reader.fields
|
||||
reader_keys = list(fields.keys())
|
||||
|
||||
parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING}
|
||||
|
||||
architecture = read_field(reader, "general.architecture")[0]
|
||||
model_name = read_field(reader, "general.name")
|
||||
|
||||
# in llama.cpp mistral models use the same architecture as llama. We need
|
||||
# to add this patch to ensure things work correctly on our side.
|
||||
if "llama" in architecture and "mistral" in model_name:
|
||||
updated_architecture = "mistral"
|
||||
else:
|
||||
updated_architecture = architecture
|
||||
|
||||
if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
|
||||
raise ValueError(f"Architecture {architecture} not supported")
|
||||
|
||||
# List all key-value pairs in a columnized format
|
||||
for gguf_key, field in reader.fields.items():
|
||||
gguf_key = gguf_key.replace(architecture, updated_architecture)
|
||||
split = gguf_key.split(".")
|
||||
prefix = split[0]
|
||||
config_key = ".".join(split[1:])
|
||||
|
||||
value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data]
|
||||
|
||||
if len(value) == 1:
|
||||
value = value[0]
|
||||
|
||||
if isinstance(value, str) and architecture in value:
|
||||
value = value.replace(architecture, updated_architecture)
|
||||
|
||||
for parameter in GGUF_TO_TRANSFORMERS_MAPPING:
|
||||
parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter]
|
||||
if prefix in parameter_renames and config_key in parameter_renames[prefix]:
|
||||
renamed_config_key = parameter_renames[prefix][config_key]
|
||||
if renamed_config_key == -1:
|
||||
continue
|
||||
|
||||
if renamed_config_key is not None:
|
||||
parsed_parameters[parameter][renamed_config_key] = value
|
||||
|
||||
if gguf_key in reader_keys:
|
||||
reader_keys.remove(gguf_key)
|
||||
|
||||
if gguf_key in reader_keys:
|
||||
logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}")
|
||||
|
||||
if return_tensors:
|
||||
tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture]
|
||||
|
||||
for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
|
||||
renamed_tensor_name = tensor.name
|
||||
|
||||
for tensor_name_mapping in GGUF_TO_TRANSFORMERS_MAPPING["tensors"]:
|
||||
if tensor_name_mapping in renamed_tensor_name:
|
||||
renamed_tensor_name = renamed_tensor_name.replace(
|
||||
tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping]
|
||||
)
|
||||
|
||||
shape = tensor.shape
|
||||
name = tensor.name
|
||||
|
||||
weights = load_dequant_gguf_tensor(shape=shape, ggml_type=tensor.tensor_type, data=tensor.data)
|
||||
|
||||
if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
|
||||
num_heads = parsed_parameters["config"]["num_attention_heads"]
|
||||
tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
|
||||
weights = weights.reshape(tmp_shape)
|
||||
weights = weights.transpose(0, 2, 1, 3)
|
||||
weights = weights.reshape(shape[::-1])
|
||||
|
||||
for tensor_name in tensor_key_mapping:
|
||||
if tensor_name in name:
|
||||
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
|
||||
|
||||
# Use copy to avoid errors with numpy and pytorch
|
||||
parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
|
||||
|
||||
if len(reader_keys) > 0:
|
||||
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
|
||||
|
||||
return parsed_parameters
|
@ -2993,6 +2993,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
adapter_name = kwargs.pop("adapter_name", "default")
|
||||
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
||||
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
# Cache path to the GGUF file
|
||||
gguf_path = None
|
||||
|
||||
if is_fsdp_enabled():
|
||||
low_cpu_mem_usage = True
|
||||
|
||||
@ -3156,6 +3160,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
|
||||
if kwarg_attn_imp is not None:
|
||||
config._attn_implementation = kwarg_attn_imp
|
||||
|
||||
model_kwargs = kwargs
|
||||
|
||||
pre_quantized = getattr(config, "quantization_config", None) is not None
|
||||
@ -3197,7 +3202,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
keep_in_fp32_modules = None
|
||||
use_keep_in_fp32_modules = False
|
||||
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if gguf_file is not None and hf_quantizer is not None:
|
||||
raise ValueError(
|
||||
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
|
||||
)
|
||||
|
||||
if pretrained_model_name_or_path is not None and gguf_file is None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if is_local:
|
||||
@ -3439,6 +3449,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
resolved_archive_file = archive_file
|
||||
else:
|
||||
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||
elif gguf_file:
|
||||
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
|
||||
|
||||
# Case 1: the GGUF file is present locally
|
||||
if os.path.isfile(gguf_file):
|
||||
gguf_path = gguf_file
|
||||
# Case 2: The GGUF path is a location on the Hub
|
||||
# Load from URL or cache if already cached
|
||||
else:
|
||||
cached_file_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"local_files_only": local_files_only,
|
||||
"token": token,
|
||||
"user_agent": user_agent,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"_raise_exceptions_for_gated_repo": False,
|
||||
"_raise_exceptions_for_missing_entries": False,
|
||||
"_commit_hash": commit_hash,
|
||||
}
|
||||
|
||||
gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs)
|
||||
|
||||
state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["tensors"]
|
||||
|
||||
resolved_archive_file = None
|
||||
is_sharded = False
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
@ -3533,7 +3573,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
else:
|
||||
loaded_state_dict_keys = list(state_dict.keys())
|
||||
if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()):
|
||||
|
||||
if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
|
||||
# In case some weights need to be kept in float32 and accelerate is not installed,
|
||||
# we later on want to take the path where state_dict is not None, that is the one
|
||||
# that does not require accelerate.
|
||||
@ -3679,6 +3720,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# restore default dtype
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
|
||||
(
|
||||
model,
|
||||
missing_keys,
|
||||
@ -3702,6 +3744,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
dtype=torch_dtype,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
gguf_path=gguf_path,
|
||||
)
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
@ -3795,9 +3838,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
dtype=None,
|
||||
hf_quantizer=None,
|
||||
keep_in_fp32_modules=None,
|
||||
gguf_path=None,
|
||||
):
|
||||
is_safetensors = False
|
||||
is_quantized = hf_quantizer is not None
|
||||
state_dict_folder = None
|
||||
state_dict_index = None
|
||||
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
archive_file = (
|
||||
@ -4055,6 +4101,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
for p, f in weight_map.items()
|
||||
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
||||
}
|
||||
else:
|
||||
offload_index = None
|
||||
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
@ -4066,11 +4114,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||
offload_index = None
|
||||
else:
|
||||
# Sharded checkpoint or whole but low_cpu_mem_usage==True
|
||||
|
||||
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
||||
if gguf_path:
|
||||
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
loaded_keys,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
offload_index=offload_index,
|
||||
state_dict_folder=state_dict_folder,
|
||||
state_dict_index=state_dict_index,
|
||||
dtype=dtype,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_safetensors,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
else:
|
||||
# Sharded checkpoint or whole but low_cpu_mem_usage==True
|
||||
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||
|
||||
else:
|
||||
# This should always be a list but, just to be sure.
|
||||
if not isinstance(resolved_archive_file, list):
|
||||
resolved_archive_file = [resolved_archive_file]
|
||||
|
@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from ...utils import (
|
||||
@ -781,6 +782,7 @@ class AutoTokenizer:
|
||||
use_fast = kwargs.pop("use_fast", True)
|
||||
tokenizer_type = kwargs.pop("tokenizer_type", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
gguf_file = kwargs.get("gguf_file", None)
|
||||
|
||||
# First, let's see whether the tokenizer_type is passed so that we can leverage it
|
||||
if tokenizer_type is not None:
|
||||
@ -827,9 +829,14 @@ class AutoTokenizer:
|
||||
# If that did not work, let's try to use the config.
|
||||
if config_tokenizer_class is None:
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
if gguf_file:
|
||||
gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
|
||||
config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
|
||||
config = AutoConfig.for_model(**config_dict)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
config_tokenizer_class = config.tokenizer_class
|
||||
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
|
||||
tokenizer_auto_map = config.auto_map["AutoTokenizer"]
|
||||
@ -887,6 +894,7 @@ class AutoTokenizer:
|
||||
model_type = config_class_to_model_type(type(config).__name__)
|
||||
if model_type is not None:
|
||||
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
|
||||
|
||||
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
|
||||
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
else:
|
||||
|
@ -74,6 +74,7 @@ from .utils import (
|
||||
is_ftfy_available,
|
||||
is_g2p_en_available,
|
||||
is_galore_torch_available,
|
||||
is_gguf_available,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
is_jinja_available,
|
||||
@ -376,6 +377,13 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_gguf(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires gguf. These tests are skipped when gguf isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_gguf_available(), "test requires gguf")(test_case)
|
||||
|
||||
|
||||
def require_fsdp(test_case, min_version: str = "1.12.0"):
|
||||
"""
|
||||
Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
|
||||
|
@ -1968,6 +1968,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
gguf_file = kwargs.get("gguf_file", None)
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@ -1995,7 +1996,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
single_file_id = None
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
if len(cls.vocab_files_names) > 1:
|
||||
if len(cls.vocab_files_names) > 1 and not gguf_file:
|
||||
raise ValueError(
|
||||
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
|
||||
"supported for this tokenizer. Use a model identifier or the path to a directory instead."
|
||||
@ -2010,42 +2011,45 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
vocab_files[file_id] = pretrained_model_name_or_path
|
||||
single_file_id = file_id
|
||||
else:
|
||||
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
||||
additional_files_names = {
|
||||
"added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy
|
||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
|
||||
"tokenizer_file": FULL_TOKENIZER_FILE,
|
||||
}
|
||||
vocab_files = {**cls.vocab_files_names, **additional_files_names}
|
||||
if "tokenizer_file" in vocab_files:
|
||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
TOKENIZER_CONFIG_FILE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||
if resolved_config_file is not None:
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
tokenizer_config = json.load(reader)
|
||||
if "fast_tokenizer_files" in tokenizer_config:
|
||||
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
|
||||
vocab_files["tokenizer_file"] = fast_tokenizer_file
|
||||
if gguf_file:
|
||||
vocab_files["vocab_file"] = gguf_file
|
||||
else:
|
||||
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
||||
additional_files_names = {
|
||||
"added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy
|
||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
|
||||
"tokenizer_file": FULL_TOKENIZER_FILE,
|
||||
}
|
||||
vocab_files = {**cls.vocab_files_names, **additional_files_names}
|
||||
if "tokenizer_file" in vocab_files:
|
||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
TOKENIZER_CONFIG_FILE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||
if resolved_config_file is not None:
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
tokenizer_config = json.load(reader)
|
||||
if "fast_tokenizer_files" in tokenizer_config:
|
||||
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
|
||||
vocab_files["tokenizer_file"] = fast_tokenizer_file
|
||||
|
||||
# Get files from url, cache, or disk depending on the case
|
||||
resolved_vocab_files = {}
|
||||
@ -2084,7 +2088,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
"files are necessary for the tokenizer to operate."
|
||||
)
|
||||
|
||||
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
|
||||
# If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
|
||||
# loaded directly from the GGUF file.
|
||||
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not gguf_file:
|
||||
raise EnvironmentError(
|
||||
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
@ -2133,8 +2139,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
|
||||
# file or if `from_slow` is set to True.
|
||||
from_slow = kwargs.get("from_slow", False)
|
||||
gguf_file = kwargs.get("gguf_file", None)
|
||||
has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
|
||||
if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None:
|
||||
|
||||
# If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
|
||||
# loaded directly from the GGUF file.
|
||||
if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file:
|
||||
slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
|
||||
copy.deepcopy(resolved_vocab_files),
|
||||
pretrained_model_name_or_path,
|
||||
|
@ -29,6 +29,8 @@ from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer

from .convert_slow_tokenizer import convert_slow_tokenizer
from .integrations.ggml import convert_gguf_tokenizer
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import (
    INIT_TOKENIZER_DOCSTRING,

@ -94,6 +96,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
    def __init__(self, *args, **kwargs):
        tokenizer_object = kwargs.pop("tokenizer_object", None)
        slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
        gguf_file = kwargs.pop("gguf_file", None)
        fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
        from_slow = kwargs.pop("from_slow", False)
        added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})

@ -112,6 +115,10 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
        elif slow_tokenizer is not None:
            # We need to convert a slow tokenizer to build the backend
            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
        elif gguf_file is not None:
            # We need to convert the tokenizer stored inside the GGUF checkpoint to build the backend
            tokenizer_dict = load_gguf_checkpoint(kwargs.get("vocab_file"))["tokenizer"]
            fast_tokenizer = convert_gguf_tokenizer(tokenizer_dict)
        elif self.slow_tokenizer_class is not None:
            # We need to create and convert a slow tokenizer to build the backend
            slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
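In the GGUF branch of `PreTrainedTokenizerFast.__init__`, `load_gguf_checkpoint` parses the file and exposes the tokenizer metadata under a "tokenizer" key, and `convert_gguf_tokenizer` builds the `tokenizers` backend from it. A rough standalone sketch of that path, assuming a local file path; the return types are as implied by the diff above rather than a documented API:

```python
from transformers.integrations.ggml import convert_gguf_tokenizer
from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint

gguf_path = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"  # illustrative local .gguf file

# Parse the GGUF metadata; the "tokenizer" entry carries the vocab, merges and added tokens.
tokenizer_dict = load_gguf_checkpoint(gguf_path)["tokenizer"]

# Build the `tokenizers` backend object that PreTrainedTokenizerFast wraps.
backend_tokenizer = convert_gguf_tokenizer(tokenizer_dict)
```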
@ -129,6 +129,7 @@ from .import_utils import (
    is_ftfy_available,
    is_g2p_en_available,
    is_galore_torch_available,
    is_gguf_available,
    is_hqq_available,
    is_in_notebook,
    is_ipex_available,

@ -152,6 +152,7 @@ _safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_is_seqio_available = _is_package_available("seqio")
_is_gguf_available = _is_package_available("gguf")
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
    try:

@ -810,6 +811,10 @@ def is_seqio_available():
    return _is_seqio_available


def is_gguf_available():
    return _is_gguf_available


def is_protobuf_available():
    if importlib.util.find_spec("google") is None:
        return False
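`is_gguf_available()` follows the same `_is_package_available` pattern as the other soft dependencies, so GGUF-specific code (and the `require_gguf` decorator used in the new tests) can guard on it without importing `gguf` eagerly. A hedged sketch of such a guard; the helper below and its use of `gguf.GGUFReader` are illustrative and not part of this diff:

```python
from transformers.utils import is_gguf_available


def read_gguf_metadata_keys(path: str) -> list[str]:
    # Fail early with an actionable message when the optional `gguf` package is missing.
    if not is_gguf_available():
        raise ImportError("Reading GGUF files requires the `gguf` package: `pip install gguf`.")

    from gguf import GGUFReader  # imported lazily, mirroring the optional-dependency pattern

    reader = GGUFReader(path)
    # The GGUF header stores model and tokenizer configuration as key/value fields.
    return list(reader.fields.keys())
```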
tests/quantization/ggml/__init__.py   (new file, 0 lines)
tests/quantization/ggml/test_ggml.py  (new file, 215 lines)
@ -0,0 +1,215 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
from transformers.utils import is_torch_available


if is_torch_available():
    import torch


@require_gguf
@require_torch_gpu
@slow
class GgufIntegrationTests(unittest.TestCase):
    original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
    mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"

    q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
    q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
    q2_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"
    q3_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf"
    q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
    q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
    q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"

    q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"

    example_text = "Hello"

    def test_q2_k(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\n[10:0"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q2_k_serialization(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            tokenizer.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname).to(torch_device)
            tokenizer = AutoTokenizer.from_pretrained(tmpdirname)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\n[10:0"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q3_k(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id).to(torch_device)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\n```\n<|user"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q5_k(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id).to(torch_device)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q4_0(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id).to(torch_device)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q4_k_m(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id).to(torch_device)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q6_k(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id).to(torch_device)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q6_k_fp16(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id, gguf_file=self.q6_k_gguf_model_id, torch_dtype=torch.float16
        ).to(torch_device)

        self.assertTrue(model.lm_head.weight.dtype == torch.float16)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_q8_0(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id).to(torch_device)

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_mistral_q4_0(self):
        tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
        model = AutoModelForCausalLM.from_pretrained(
            self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
        )

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello,\n\nI'm trying to create a"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_tokenization_xnli(self):
        import tqdm
        from datasets import load_dataset

        gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
        original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)

        dataset = load_dataset("code_x_glue_ct_code_to_text", "go")
        for item in tqdm.tqdm(dataset["validation"]):
            string = item["code"]
            encoded1 = gguf_tokenizer.encode(string)
            encoded2 = original_tokenizer.encode(string)

            self.assertEqual(encoded1, encoded2)

            decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
            decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)

            self.assertEqual(decoded1, decoded2)

        dataset = load_dataset("xnli", "all_languages")

        for i, item in enumerate(tqdm.tqdm(dataset["train"].select(range(100)))):
            for string in item["premise"].values():
                encoded1 = gguf_tokenizer.encode(string)
                encoded2 = original_tokenizer.encode(string)

                self.assertEqual(encoded1, encoded2)

                decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
                decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)

                self.assertEqual(decoded1, decoded2)

        # With special tokens
        gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
        original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)

        gguf_tokenizer.add_special_tokens(
            {"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
        )
        original_tokenizer.add_special_tokens(
            {"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
        )

        text = "Hello <token>. <token> Hello"

        encoded1 = gguf_tokenizer.encode(text)
        encoded2 = original_tokenizer.encode(text)

        self.assertEqual(encoded1, encoded2)

        decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
        decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)

        self.assertEqual(decoded1, decoded2)
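Because the GGUF tensors are dequantized into ordinary torch parameters at load time, the round-trip in `test_q2_k_serialization` above doubles as a recipe for turning a GGUF checkpoint into a regular Transformers checkpoint. A condensed sketch, with an illustrative output directory:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
gguf_file = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"

model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=gguf_file)
tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)

# The saved checkpoint is a plain (dequantized) Transformers model and tokenizer,
# which can be reloaded later without the original .gguf file.
model.save_pretrained("tinyllama-dequantized")
tokenizer.save_pretrained("tinyllama-dequantized")
```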
@ -331,6 +331,7 @@ IGNORE_SUBMODULES = [
    "models.esm.openfold_utils",
    "modeling_attn_mask_utils",
    "safetensors_conversion",
    "modeling_gguf_pytorch_utils",
]