Loading GGUF files support (#30391)

* Adds support for loading GGUF files

Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
Co-authored-by: 99991 <99991@users.noreply.github.com>

* add q2_k q3_k q5_k support from @99991

* fix tests

* Update doc

* Style

* Docs

* fix CI

* Update docs/source/en/gguf.md

* Update docs/source/en/gguf.md

* Compute merges

* change logic

* add comment for clarity

* add comment for clarity

* Update src/transformers/models/auto/tokenization_auto.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change logic

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_gguf_pytorch_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* put back comment

* add comment about mistral

* comments and added tests

* fix inconsistent type

* more

* fix tokenizer

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* address comments about tests and tokenizer + add added_tokens

* from_gguf -> gguf_file

* replace on docs too

---------

Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
Co-authored-by: 99991 <99991@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Lysandre Debut 2024-05-15 14:28:20 +02:00 committed by GitHub
parent bd9f4d7951
commit a42844955f
17 changed files with 1248 additions and 52 deletions

View File

@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
# For video model testing
RUN python3 -m pip install --no-cache-dir decord av==9.2.0
# For GGUF tests
RUN python3 -m pip install --no-cache-dir gguf
# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes

View File

@ -137,6 +137,8 @@
title: Troubleshoot
- local: hf_quantizer
title: Contribute new quantization method
- local: gguf
title: Interoperability with GGUF files
title: Developer guides
- sections:
- local: performance

96
docs/source/en/gguf.md Normal file
View File

@ -0,0 +1,96 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# GGUF and interaction with Transformers
The GGUF file format is used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and other
libraries that depend on it, like the very popular [llama.cpp](https://github.com/ggerganov/llama.cpp) or
[whisper.cpp](https://github.com/ggerganov/whisper.cpp).
It is a file format [supported by the Hugging Face Hub](https://huggingface.co/docs/hub/en/gguf) with features
allowing for quick inspection of tensors and metadata within the file.
This file format is designed as a "single-file format", where a single file usually contains the configuration
attributes, the tokenizer vocabulary and other attributes, as well as all tensors to be loaded in the model. These
files come in different formats according to the quantization type of the file. We briefly go over some of them
[here](https://huggingface.co/docs/hub/en/gguf#quantization-types).
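If you want to peek inside such a file before loading it in `transformers`, the metadata and the tensor listing can be
inspected with the [`gguf`](https://github.com/ggerganov/llama.cpp/tree/master/gguf-py) Python package. Below is a
minimal sketch (the local `model.gguf` path is just a placeholder):

```py
from gguf import GGUFReader

# Open a local GGUF file (placeholder path) and read its metadata
reader = GGUFReader("model.gguf")

# Key/value metadata: architecture, context length, tokenizer vocabulary, ...
print(list(reader.fields.keys())[:10])

# Tensor listing: name, shape and quantization type of every stored tensor
for tensor in reader.tensors[:5]:
    print(tensor.name, tensor.shape, tensor.tensor_type)
```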
## Support within Transformers
We have added the ability to load `gguf` files within `transformers` in order to offer further training/fine-tuning
capabilities to `gguf` models, before converting those models back to `gguf` for use within the `ggml` ecosystem. When
loading a model, we first dequantize it to fp32 before loading the weights to be used in PyTorch.
> [!NOTE]
> The support is still very exploratory and we welcome contributions in order to solidify it across quantization types
> and model architectures.
For now, here are the supported model architectures and quantization types:
### Supported quantization types
The initial supported quantization types were chosen according to the popular quantized files that have been shared
on the Hub.
- F32
- Q2_K
- Q3_K
- Q4_0
- Q4_K
- Q5_K
- Q6_K
- Q8_0
We take inspiration from the excellent [99991/pygguf](https://github.com/99991/pygguf) Python parser to dequantize the
weights.
### Supported model architectures
For now, the supported model architectures are those that are very popular on the Hub, namely:
- LLaMa
- Mistral
## Example usage
In order to load `gguf` files in `transformers`, you should specify the `gguf_file` argument to the `from_pretrained`
methods of both tokenizers and models. Here is how one would load a tokenizer and a model, which can be loaded
from the exact same file:
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
filename = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename)
```
Now you have access to the full, unquantized version of the model in the PyTorch ecosystem, where you can combine it
with a plethora of other tools.
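For example, you can check that the weights were indeed dequantized and run a quick generation with the objects loaded
above (a minimal sketch; the exact output depends on the checkpoint):

```py
import torch

# The GGUF tensors are dequantized to fp32 before being loaded into PyTorch
assert model.dtype == torch.float32

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```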
In order to convert back to a `gguf` file, we recommend using the
[`convert-hf-to-gguf.py` file](https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py) from llama.cpp.
Here's how you would complete the script above to save the model and export it back to `gguf`:
```py
tokenizer.save_pretrained('directory')
model.save_pretrained('directory')
!python ${path_to_llama_cpp}/convert-hf-to-gguf.py ${directory}
```

View File

@ -27,6 +27,7 @@ from packaging import version
from . import __version__
from .dynamic_module_utils import custom_object_save
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
from .utils import (
CONFIG_NAME,
PushToHubMixin,
@ -658,6 +659,8 @@ class PretrainedConfig(PushToHubMixin):
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)
gguf_file = kwargs.get("gguf_file", None)
if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
@ -676,10 +679,10 @@ class PretrainedConfig(PushToHubMixin):
resolved_config_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
configuration_file = pretrained_model_name_or_path
configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file
resolved_config_file = download_url(pretrained_model_name_or_path)
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
try:
# Load from local folder or from cache or download from model Hub and cache
@ -712,8 +715,12 @@ class PretrainedConfig(PushToHubMixin):
)
try:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
if gguf_file:
config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
else:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(

View File

@ -44,6 +44,14 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"ggml": [
"GGUF_CONFIG_MAPPING",
"GGUF_TENSOR_MAPPING",
"GGUF_TOKENIZER_MAPPING",
"_gguf_parse_value",
"load_dequant_gguf_tensor",
"load_gguf",
],
"hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
@ -116,6 +124,14 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .ggml import (
GGUF_CONFIG_MAPPING,
GGUF_TENSOR_MAPPING,
GGUF_TOKENIZER_MAPPING,
_gguf_parse_value,
load_dequant_gguf_tensor,
load_gguf,
)
from .hqq import prepare_for_hqq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,

View File

@ -0,0 +1,584 @@
# coding=utf-8
# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
# https://github.com/99991/pygguf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integration with GGML / The file is copied and adapted from https://github.com/99991/pygguf
with extra methods being exposed
"""
from array import array
import numpy as np
from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE
from .. import AddedToken
from ..convert_slow_tokenizer import LlamaConverter
from ..utils import logging
from ..utils.logging import tqdm
logger = logging.get_logger(__name__)
# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
GGML_TYPES = {
"F32": 0,
"Q4_0": 2,
"Q8_0": 8,
"Q2_K": 10,
"Q3_K": 11,
"Q4_K": 12,
"Q5_K": 13,
"Q6_K": 14,
}
# The block sizes are reported in bytes
# Check out: https://github.com/ggerganov/llama.cpp/blob/8a56075b07a8b571bf95a912ffdce4c928c2b414/gguf-py/gguf/constants.py#L801
GGML_BLOCK_SIZES = {
"Q8_0": 2 + 32, # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales
"Q4_K": 144,
# Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales
"Q4_0": 2 + 16,
"Q6_K": 210,
# See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
}
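# As an illustration of how these sizes decompose: a Q2_K super-block covers 256 weights and takes
# 256 // 16 bytes of packed scale/min pairs + 256 // 4 bytes of 2-bit quants + 2 + 2 bytes for the fp16 d and dmin,
# i.e. 16 + 64 + 4 = 84 bytes, which is exactly GGML_BLOCK_SIZES["Q2_K"].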
# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
DATA_TYPES = {
"uint32": 4,
"int32": 5,
"float32": 6,
"bool": 7,
"string": 8,
"array": 9,
"uint64": 10,
}
GGUF_TENSOR_MAPPING = {
"llama": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"mistral": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}
GGUF_CONFIG_MAPPING = {
"general": {
"architecture": "model_type",
"name": "_model_name_or_path",
},
"llama": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"mistral": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.model": "model_type",
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
"ggml.unknown_token_id": "unk_token_id",
"ggml.padding_token_id": "pad_token_id",
},
}
GGUF_TOKENIZER_MAPPING = {
"tokenizer": {
"ggml.model": "tokenizer_type",
"ggml.tokens": "tokens",
"ggml.scores": "scores",
"ggml.token_type": "token_type",
"ggml.merges": "merges",
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
"ggml.unknown_token_id": "unk_token_id",
"ggml.padding_token_id": "pad_token_id",
"ggml.add_space_prefix": "add_prefix_space",
},
"tokenizer_config": {
"chat_template": "chat_template",
"ggml.model": "model_type",
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
"ggml.unknown_token_id": "unk_token_id",
"ggml.padding_token_id": "pad_token_id",
},
}
def _gguf_parse_value(_value, data_type):
if not isinstance(data_type, list):
data_type = [data_type]
if len(data_type) == 1:
data_type = data_type[0]
array_data_type = None
else:
if data_type[0] != 9:
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
data_type, array_data_type = data_type
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
_value = int(_value[0])
elif data_type in [6, 12]:
_value = float(_value[0])
elif data_type in [7]:
_value = bool(_value[0])
elif data_type in [8]:
_value = array("B", list(_value)).tobytes().decode()
elif data_type in [9]:
_value = _gguf_parse_value(_value, array_data_type)
return _value
def dequantize_q4_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
block_size = GGML_BLOCK_SIZES["Q4_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# Casting to float32 because float16 is very slow on CPU
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
# Dequantize scales and offsets (6 bits and 4 + 2 bits)
factors = scale_factors * np.concatenate(
[qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1
)
offsets = scale_offsets * np.concatenate(
[qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1
)
# Interleave low and high quantized bits
qs2 = np.stack([qs2 & 0xF, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
# Dequantize final weights using scales and offsets
return factors * qs2 - offsets
def dequantize_q4_0(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11
block_size = GGML_BLOCK_SIZES["Q4_0"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# The scales are stored on the first 2 bytes and the rest corresponds to the quants
scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
# scales = np.nan_to_num(scales)
# the rest of the bytes corresponds to the quants - we discard the first two bytes
quants = data_u8[:, 2:]
ql = (quants[:, :] & 0xF).astype(np.int8) - 8
qr = (quants[:, :] >> 4).astype(np.int8) - 8
# Use hstack
quants = np.hstack([ql, qr])
return (scales * quants).astype(np.float32)
def dequantize_q6_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152
block_size = GGML_BLOCK_SIZES["Q6_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
# TODO use uint8 and cast later?
ql = data_u8[:, :128].astype(np.int16)
qh = data_u8[:, 128:192].astype(np.int16)
sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
# Unpack bits, subtraction requires signed data type
q1 = (ql[:, :32] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
q2 = (ql[:, 32:64] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
q3 = (ql[:, :32] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
q4 = (ql[:, 32:64] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
q5 = (ql[:, 64:96] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
q7 = (ql[:, 64:96] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
# Dequantize
return scales * np.concatenate(
[
sc[:, 0] * q1[:, :16],
sc[:, 1] * q1[:, 16:],
sc[:, 2] * q2[:, :16],
sc[:, 3] * q2[:, 16:],
sc[:, 4] * q3[:, :16],
sc[:, 5] * q3[:, 16:],
sc[:, 6] * q4[:, :16],
sc[:, 7] * q4[:, 16:],
sc[:, 8] * q5[:, :16],
sc[:, 9] * q5[:, 16:],
sc[:, 10] * q6[:, :16],
sc[:, 11] * q6[:, 16:],
sc[:, 12] * q7[:, :16],
sc[:, 13] * q7[:, 16:],
sc[:, 14] * q8[:, :16],
sc[:, 15] * q8[:, 16:],
],
axis=1,
)
def dequantize_q8_0(data):
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
block_size = GGML_BLOCK_SIZES["Q8_0"]
num_blocks = len(data) // block_size
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
return scales * qs
def dequantize_q2_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74
num_blocks = len(data) // GGML_BLOCK_SIZES["Q2_K"]
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"] // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"])
dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
qs = data_u8[:, 16:80].reshape(num_blocks, 64)
tmp = np.stack(
[
qs[:, 00:16] >> 0,
qs[:, 16:32] >> 0,
qs[:, 00:16] >> 2,
qs[:, 16:32] >> 2,
qs[:, 00:16] >> 4,
qs[:, 16:32] >> 4,
qs[:, 00:16] >> 6,
qs[:, 16:32] >> 6,
qs[:, 32:48] >> 0,
qs[:, 48:64] >> 0,
qs[:, 32:48] >> 2,
qs[:, 48:64] >> 2,
qs[:, 32:48] >> 4,
qs[:, 48:64] >> 4,
qs[:, 32:48] >> 6,
qs[:, 48:64] >> 6,
],
axis=1,
)
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
def dequantize_q3_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95
num_blocks = len(data) // GGML_BLOCK_SIZES["Q3_K"]
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"] // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"])
d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
bits = 4 ^ (bits << 2)
qs = data_u8[:, 32 : 32 + 64].astype(np.int16)
a, b, c = data_u8[:, 96 : 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
scales[:, 0] = (a & 15) | ((c & 3) << 4)
scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
return (
d
* (scales - 32)
* np.stack(
[
(((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
(((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
(((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
(((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
(((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
(((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
(((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
(((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
(((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
(((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
(((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
(((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
(((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
(((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
(((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
(((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]),
],
axis=1,
)
)
def dequantize_q5_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138
num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_K"]
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"] // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"])
d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qh = data_u8[:, 16 : 16 + 32].reshape(num_blocks, 32, 1)
qs = data_u8[:, 48 : 48 + 128].reshape(num_blocks, 4, 32)
bits = np.unpackbits(qh, axis=-1, bitorder="little")
qs_hi_4 = qs >> 4
qs_lo_4 = qs & 15
scales_lo_6 = scales[:, :8] & 63
scales_hi_6 = scales[:, :8] >> 6
scales_lo_4 = scales[:, 8:] & 15
scales_hi_4 = scales[:, 8:] >> 4
m1 = dmin * scales_lo_6[:, 4]
m2 = dmin * scales_lo_6[:, 5]
m3 = dmin * scales_lo_6[:, 6]
m4 = dmin * scales_lo_6[:, 7]
m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
d1 = d * scales_lo_6[:, 0]
d2 = d * scales_lo_6[:, 1]
d3 = d * scales_lo_6[:, 2]
d4 = d * scales_lo_6[:, 3]
d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
return np.concatenate(
[
d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
],
axis=1,
)
def load_dequant_gguf_tensor(shape, ggml_type, data):
if ggml_type == GGML_TYPES["F32"]:
values = data
elif ggml_type == GGML_TYPES["Q8_0"]:
values = dequantize_q8_0(data)
elif ggml_type == GGML_TYPES["Q4_0"]:
values = dequantize_q4_0(data)
elif ggml_type == GGML_TYPES["Q4_K"]:
values = dequantize_q4_k(data)
elif ggml_type == GGML_TYPES["Q6_K"]:
values = dequantize_q6_k(data)
elif ggml_type == GGML_TYPES["Q2_K"]:
values = dequantize_q2_k(data)
elif ggml_type == GGML_TYPES["Q3_K"]:
values = dequantize_q3_k(data)
elif ggml_type == GGML_TYPES["Q5_K"]:
values = dequantize_q5_k(data)
else:
raise NotImplementedError(
f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose"
)
return values.reshape(shape[::-1])
class GGUFTokenizerSkeleton:
def __init__(self, dict_):
for k, v in dict_.items():
setattr(self, k, v)
if not hasattr(self, "tokens") or not hasattr(self, "scores"):
raise ValueError("tokens and scores need to be passed for a LLaMa tokenizer to be instantiated.")
else:
tokens = self.tokens
scores = self.scores
vocab = {t: scores[i] for i, t in enumerate(tokens)}
if not hasattr(self, "merges"):
logger.warning("Merges were not in checkpoint, building merges on the fly.")
merges = []
for merge, piece_score in tqdm(vocab.items()):
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in tokens and piece_r in tokens:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]), reverse=True)
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=True)
merges = [(val[0], val[1]) for val in merges]
self.merges = merges
else:
self.merges = [tuple(merge.split(" ")) for merge in self.merges]
if not hasattr(self, "added_tokens"):
self.added_tokens = []
class GGUFLlamaConverter(LlamaConverter):
def __init__(self, tokenizer_dict):
self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
self.original_tokenizer = self.proto
def vocab(self, proto):
return list(zip(proto.tokens, proto.scores))
def merges(self, proto):
return proto.merges
def tokenizer(self, proto):
vocab_scores = self.vocab(self.proto)
merges = self.merges(self.proto)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(bpe_vocab, merges, unk_token=proto.tokens[proto.unk_token_id], fuse_unk=True, byte_fallback=True)
)
tokenizer.add_special_tokens(
[
AddedToken("<unk>", normalized=False, special=True),
AddedToken("<s>", normalized=False, special=True),
AddedToken("</s>", normalized=False, special=True),
]
)
if len(self.proto.added_tokens) != 0:
tokenizer.add_special_tokens(
[AddedToken(added_token, normalized=False, special=False) for added_token in self.proto.added_tokens]
)
return tokenizer
def decoder(self, replacement, add_prefix_space):
sequence = [
decoders.ByteFallback(),
decoders.Fuse(),
decoders.Replace("", " "),
]
if add_prefix_space:
sequence += [decoders.Strip(content=" ", left=1)]
return decoders.Sequence(sequence)
GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
}
def convert_gguf_tokenizer(tokenizer_dict) -> Tokenizer:
"""
Utilities to convert a GGUF tokenizer dictionary into a fast tokenizer instance.
Args:
tokenizer_dict (`dict`):
The tokenizer attributes (tokens, scores, merges, added tokens, ...) parsed from a GGUF file, used to
build the backend tokenizer for [`~tokenization_utils_base.PreTrainedTokenizerFast`].
Return:
An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
[`~tokenization_utils_base.PreTrainedTokenizerFast`]
"""
tokenizer_class_name = tokenizer_dict["tokenizer_type"]
converter_class = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(tokenizer_dict).converted()
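As a quick sanity check of the dequantization helpers above, here is a minimal sketch that hand-builds a single Q8_0
block (an fp16 scale followed by 32 int8 quants, matching `GGML_BLOCK_SIZES["Q8_0"]`) and runs it through the
dispatcher:

```py
import numpy as np

# One hand-crafted Q8_0 block: 2 bytes of fp16 scale + 32 int8 quants = 34 bytes
scale = np.float16(0.5)
quants = np.arange(-16, 16, dtype=np.int8)
data = scale.tobytes() + quants.tobytes()

# Dequantize through the dispatcher defined above
values = load_dequant_gguf_tensor(shape=(32,), ggml_type=GGML_TYPES["Q8_0"], data=data)
assert np.allclose(values, np.float32(scale) * quants)
```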

View File

@ -0,0 +1,165 @@
# coding=utf-8
# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
# https://github.com/99991/pygguf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from tqdm import tqdm
from .integrations import (
GGUF_CONFIG_MAPPING,
GGUF_TENSOR_MAPPING,
GGUF_TOKENIZER_MAPPING,
_gguf_parse_value,
load_dequant_gguf_tensor,
)
from .utils import is_torch_available
from .utils.logging import get_logger
if is_torch_available():
import torch
logger = get_logger(__name__)
GGUF_TO_TRANSFORMERS_MAPPING = {
"ignore": {
"GGUF": {
"version": "version",
"tensor_count": "tensor_count",
"kv_count": "kv_count",
},
"general": {"file_type": "file_type", "quantization_version": "quantization_version"},
},
"config": GGUF_CONFIG_MAPPING,
"tensors": GGUF_TENSOR_MAPPING,
"tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]},
"tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]},
}
GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())
def read_field(reader, field):
value = reader.fields[field]
return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
"""
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed
tokenizer and config attributes.
Args:
gguf_checkpoint_path (`str`):
The path to the GGUF file to load
return_tensors (`bool`, defaults to `False`):
Whether to read the tensors from the file and return them. Not doing so is faster
and only loads the metadata in memory.
"""
try:
from gguf import GGUFReader
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF to be installed. Please see "
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
)
raise
reader = GGUFReader(gguf_checkpoint_path)
fields = reader.fields
reader_keys = list(fields.keys())
parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING}
architecture = read_field(reader, "general.architecture")[0]
model_name = read_field(reader, "general.name")
# in llama.cpp mistral models use the same architecture as llama. We need
# to add this patch to ensure things work correctly on our side.
if "llama" in architecture and "mistral" in model_name:
updated_architecture = "mistral"
else:
updated_architecture = architecture
if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture} not supported")
# List all key-value pairs in a columnized format
for gguf_key, field in reader.fields.items():
gguf_key = gguf_key.replace(architecture, updated_architecture)
split = gguf_key.split(".")
prefix = split[0]
config_key = ".".join(split[1:])
value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data]
if len(value) == 1:
value = value[0]
if isinstance(value, str) and architecture in value:
value = value.replace(architecture, updated_architecture)
for parameter in GGUF_TO_TRANSFORMERS_MAPPING:
parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter]
if prefix in parameter_renames and config_key in parameter_renames[prefix]:
renamed_config_key = parameter_renames[prefix][config_key]
if renamed_config_key == -1:
continue
if renamed_config_key is not None:
parsed_parameters[parameter][renamed_config_key] = value
if gguf_key in reader_keys:
reader_keys.remove(gguf_key)
if gguf_key in reader_keys:
logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}")
if return_tensors:
tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture]
for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
renamed_tensor_name = tensor.name
for tensor_name_mapping in GGUF_TO_TRANSFORMERS_MAPPING["tensors"]:
if tensor_name_mapping in renamed_tensor_name:
renamed_tensor_name = renamed_tensor_name.replace(
tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping]
)
shape = tensor.shape
name = tensor.name
weights = load_dequant_gguf_tensor(shape=shape, ggml_type=tensor.tensor_type, data=tensor.data)
if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
num_heads = parsed_parameters["config"]["num_attention_heads"]
tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
weights = weights.reshape(tmp_shape)
weights = weights.transpose(0, 2, 1, 3)
weights = weights.reshape(shape[::-1])
for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
# Use copy to avoid errors with numpy and pytorch
parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
if len(reader_keys) > 0:
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
return parsed_parameters
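A hedged usage sketch of the loader above (the GGUF file is assumed to already be available locally; the filename is
only illustrative):

```py
from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint

# Metadata only: fast, no tensor dequantization
parsed = load_gguf_checkpoint("tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", return_tensors=False)
print(parsed["config"]["model_type"], parsed["config"]["num_hidden_layers"])

# Full load: additionally dequantizes every tensor into an fp32 torch state dict
state_dict = load_gguf_checkpoint("tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", return_tensors=True)["tensors"]
```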

View File

@ -2993,6 +2993,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
gguf_file = kwargs.pop("gguf_file", None)
# Cache path to the GGUF file
gguf_path = None
if is_fsdp_enabled():
low_cpu_mem_usage = True
@ -3156,6 +3160,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
pre_quantized = getattr(config, "quantization_config", None) is not None
@ -3197,7 +3202,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules = None
use_keep_in_fp32_modules = False
if pretrained_model_name_or_path is not None:
if gguf_file is not None and hf_quantizer is not None:
raise ValueError(
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
)
if pretrained_model_name_or_path is not None and gguf_file is None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
@ -3439,6 +3449,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
elif gguf_file:
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
# Case 1: the GGUF file is present locally
if os.path.isfile(gguf_file):
gguf_path = gguf_file
# Case 2: The GGUF path is a location on the Hub
# Load from URL or cache if already cached
else:
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"token": token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs)
state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["tensors"]
resolved_archive_file = None
is_sharded = False
else:
resolved_archive_file = None
@ -3533,7 +3573,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = list(state_dict.keys())
if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()):
if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
# In case some weights need to be kept in float32 and accelerate is not installed,
# we later on want to take the path where state_dict is not None, that is the one
# that do not require accelerate.
@ -3679,6 +3720,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
(
model,
missing_keys,
@ -3702,6 +3744,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
)
# make sure token embedding weights are still tied if needed
@ -3795,9 +3838,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dtype=None,
hf_quantizer=None,
keep_in_fp32_modules=None,
gguf_path=None,
):
is_safetensors = False
is_quantized = hf_quantizer is not None
state_dict_folder = None
state_dict_index = None
if device_map is not None and "disk" in device_map.values():
archive_file = (
@ -4055,6 +4101,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for p, f in weight_map.items()
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
else:
offload_index = None
if state_dict is not None:
# Whole checkpoint
@ -4066,11 +4114,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
offload_index = None
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
# For GGUF models `state_dict` is never set to None as the state dict is always small
if gguf_path:
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
else:
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]

View File

@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...utils import (
@ -781,6 +782,7 @@ class AutoTokenizer:
use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
gguf_file = kwargs.get("gguf_file", None)
# First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None:
@ -827,9 +829,14 @@ class AutoTokenizer:
# If that did not work, let's try to use the config.
if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
if gguf_file:
gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
config = AutoConfig.for_model(**config_dict)
else:
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
config_tokenizer_class = config.tokenizer_class
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
tokenizer_auto_map = config.auto_map["AutoTokenizer"]
@ -887,6 +894,7 @@ class AutoTokenizer:
model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:

View File

@ -74,6 +74,7 @@ from .utils import (
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@ -376,6 +377,13 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
)(test_case)
def require_gguf(test_case):
"""
Decorator marking a test that requires gguf. These tests are skipped when gguf isn't installed.
"""
return unittest.skipUnless(is_gguf_available(), "test requires gguf")(test_case)
def require_fsdp(test_case, min_version: str = "1.12.0"):
"""
Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.

View File

@ -1968,6 +1968,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)
gguf_file = kwargs.get("gguf_file", None)
if use_auth_token is not None:
warnings.warn(
@ -1995,7 +1996,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
is_local = os.path.isdir(pretrained_model_name_or_path)
single_file_id = None
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
if len(cls.vocab_files_names) > 1 and not gguf_file:
raise ValueError(
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
"supported for this tokenizer. Use a model identifier or the path to a directory instead."
@ -2010,42 +2011,45 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
vocab_files[file_id] = pretrained_model_name_or_path
single_file_id = file_id
else:
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
additional_files_names = {
"added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
"tokenizer_file": FULL_TOKENIZER_FILE,
}
vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
vocab_files["tokenizer_file"] = fast_tokenizer_file
if gguf_file:
vocab_files["vocab_file"] = gguf_file
else:
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
additional_files_names = {
"added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
"tokenizer_file": FULL_TOKENIZER_FILE,
}
vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
vocab_files["tokenizer_file"] = fast_tokenizer_file
# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
@ -2084,7 +2088,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"files are necessary for the tokenizer to operate."
)
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
# If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
# loaded directly from the GGUF file.
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not gguf_file:
raise EnvironmentError(
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
@ -2133,8 +2139,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
# file or if `from_slow` is set to True.
from_slow = kwargs.get("from_slow", False)
gguf_file = kwargs.get("gguf_file", None)
has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None:
# If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
# loaded directly from the GGUF file.
if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file:
slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
copy.deepcopy(resolved_vocab_files),
pretrained_model_name_or_path,

View File

@ -29,6 +29,8 @@ from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
from .convert_slow_tokenizer import convert_slow_tokenizer
from .integrations.ggml import convert_gguf_tokenizer
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import (
INIT_TOKENIZER_DOCSTRING,
@ -94,6 +96,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def __init__(self, *args, **kwargs):
tokenizer_object = kwargs.pop("tokenizer_object", None)
slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
gguf_file = kwargs.pop("gguf_file", None)
fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
from_slow = kwargs.pop("from_slow", False)
added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
@ -112,6 +115,10 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
elif slow_tokenizer is not None:
# We need to convert a slow tokenizer to build the backend
fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
elif gguf_file is not None:
# We need to convert the GGUF tokenizer to build the backend
tokenizer_dict = load_gguf_checkpoint(kwargs.get("vocab_file"))["tokenizer"]
fast_tokenizer = convert_gguf_tokenizer(tokenizer_dict)
elif self.slow_tokenizer_class is not None:
# We need to create and convert a slow tokenizer to build the backend
slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)

View File

@ -129,6 +129,7 @@ from .import_utils import (
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_hqq_available,
is_in_notebook,
is_ipex_available,

View File

@ -152,6 +152,7 @@ _safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_is_seqio_available = _is_package_available("seqio")
_is_gguf_available = _is_package_available("gguf")
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
try:
@ -810,6 +811,10 @@ def is_seqio_available():
return _is_seqio_available
def is_gguf_available():
return _is_gguf_available
def is_protobuf_available():
if importlib.util.find_spec("google") is None:
return False

View File

View File

@ -0,0 +1,215 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
from transformers.utils import is_torch_available
if is_torch_available():
import torch
@require_gguf
@require_torch_gpu
@slow
class GgufIntegrationTests(unittest.TestCase):
original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
q2_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"
q3_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf"
q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
example_text = "Hello"
def test_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n[10:0"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q2_k_serialization(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
tokenizer.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(tmpdirname)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n[10:0"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q3_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n```\n<|user"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q5_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q4_k_m(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q6_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q6_k_fp16(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.model_id, gguf_file=self.q6_k_gguf_model_id, torch_dtype=torch.float16
).to(torch_device)
self.assertTrue(model.lm_head.weight.dtype == torch.float16)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q8_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_mistral_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello,\n\nI'm trying to create a"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
dataset = load_dataset("code_x_glue_ct_code_to_text", "go")
for item in tqdm.tqdm(dataset["validation"]):
string = item["code"]
encoded1 = gguf_tokenizer.encode(string)
encoded2 = original_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
dataset = load_dataset("xnli", "all_languages")
for i, item in enumerate(tqdm.tqdm(dataset["train"].select(range(100)))):
for string in item["premise"].values():
encoded1 = gguf_tokenizer.encode(string)
encoded2 = original_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
# With special tokens
gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
gguf_tokenizer.add_special_tokens(
{"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
)
original_tokenizer.add_special_tokens(
{"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
)
text = "Hello <token>. <token> Hello"
encoded1 = gguf_tokenizer.encode(text)
encoded2 = original_tokenizer.encode(text)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)

View File

@ -331,6 +331,7 @@ IGNORE_SUBMODULES = [
"models.esm.openfold_utils",
"modeling_attn_mask_utils",
"safetensors_conversion",
"modeling_gguf_pytorch_utils",
]