🚨 Support dequantization for most GGML types (#32625)

* use gguf internal dequantize

* add Q5_0 test

* add iq1 test

* add remained test

* remove duplicated test

* update docs

* add gguf version limit

* make style

* update gguf import catch

* revert vocab_size patch

* make style

* use GGUF_MIN_VERSION everywhere
This commit is contained in:
Isotr0py 2024-09-03 18:58:14 +08:00 committed by GitHub
parent 979f4774f6
commit edeca4387c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 169 additions and 356 deletions

View File

@ -46,16 +46,30 @@ The initial supported quantization types are decided according to the popular qu
on the Hub.
- F32
- F16
- BF16
- Q4_0
- Q4_1
- Q5_0
- Q5_1
- Q8_0
- Q2_K
- Q3_K
- Q4_0
- Q4_K
- Q5_K
- Q6_K
- Q8_0
- IQ1_S
- IQ1_M
- IQ2_XXS
- IQ2_XS
- IQ2_S
- IQ3_XXS
- IQ3_S
- IQ4_XS
- IQ4_NL
We take example from the excellent [99991/pygguf](https://github.com/99991/pygguf) Python parser to dequantize the
weights.
> [!NOTE]
> To support gguf dequantization, `gguf>=0.10.0` installation is required.
### Supported model architectures

View File

@ -33,44 +33,6 @@ from ..utils.logging import tqdm
logger = logging.get_logger(__name__)
# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
GGML_TYPES = {
"F32": 0,
"F16": 1,
"Q4_0": 2,
"Q8_0": 8,
"Q2_K": 10,
"Q3_K": 11,
"Q4_K": 12,
"Q5_K": 13,
"Q6_K": 14,
}
# The Blocksizes are reported in bytes
# Check out: https://github.com/ggerganov/llama.cpp/blob/8a56075b07a8b571bf95a912ffdce4c928c2b414/gguf-py/gguf/constants.py#L801
GGML_BLOCK_SIZES = {
"Q8_0": 2 + 32, # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales
"Q4_K": 144,
# Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales
"Q4_0": 2 + 16,
"Q6_K": 210,
# See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
}
# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
DATA_TYPES = {
"uint32": 4,
"int32": 5,
"float32": 6,
"bool": 7,
"string": 8,
"array": 9,
"uint64": 10,
}
GGUF_TENSOR_MAPPING = {
"llama": {
"token_embd": "model.embed_tokens",
@ -217,303 +179,6 @@ def _gguf_parse_value(_value, data_type):
return _value
def dequantize_q4_k(data, n_bytes: int):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
block_size = GGML_BLOCK_SIZES["Q4_K"]
num_blocks = n_bytes // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# Casting to float32 because float16 is very slow on CPU
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
# Dequantize scales and offsets (6 bits and 4 + 2 bits)
factors = scale_factors * np.concatenate(
[qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1
)
offsets = scale_offsets * np.concatenate(
[qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1
)
# Interleave low and high quantized bits
qs2 = np.stack([qs2 & 0xF, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
# Dequantize final weights using scales and offsets
return factors * qs2 - offsets
def dequantize_q4_0(data, n_bytes: int):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11
block_size = GGML_BLOCK_SIZES["Q4_0"]
num_blocks = n_bytes // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# The scales are stored on the first 2 bytes and the rest corresponds to the quants
scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
# scales = np.nan_to_num(scales)
# the rest of the bytes corresponds to the quants - we discard the first two bytes
quants = data_u8[:, 2:]
ql = (quants[:, :] & 0xF).astype(np.int8) - 8
qr = (quants[:, :] >> 4).astype(np.int8) - 8
# Use hstack
quants = np.hstack([ql, qr])
return (scales * quants).astype(np.float32)
def dequantize_q6_k(data, n_bytes: int):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152
block_size = GGML_BLOCK_SIZES["Q6_K"]
num_blocks = n_bytes // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
# TODO use uint8 and cast later?
ql = data_u8[:, :128].astype(np.int16)
qh = data_u8[:, 128:192].astype(np.int16)
sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
# Unpack bits, subtraction requires signed data type
q1 = (ql[:, :32] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
q2 = (ql[:, 32:64] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
q3 = (ql[:, :32] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
q4 = (ql[:, 32:64] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
q5 = (ql[:, 64:96] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
q7 = (ql[:, 64:96] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
# Dequantize
return scales * np.concatenate(
[
sc[:, 0] * q1[:, :16],
sc[:, 1] * q1[:, 16:],
sc[:, 2] * q2[:, :16],
sc[:, 3] * q2[:, 16:],
sc[:, 4] * q3[:, :16],
sc[:, 5] * q3[:, 16:],
sc[:, 6] * q4[:, :16],
sc[:, 7] * q4[:, 16:],
sc[:, 8] * q5[:, :16],
sc[:, 9] * q5[:, 16:],
sc[:, 10] * q6[:, :16],
sc[:, 11] * q6[:, 16:],
sc[:, 12] * q7[:, :16],
sc[:, 13] * q7[:, 16:],
sc[:, 14] * q8[:, :16],
sc[:, 15] * q8[:, 16:],
],
axis=1,
)
def dequantize_q8_0(data, n_bytes: int):
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
block_size = GGML_BLOCK_SIZES["Q8_0"]
num_blocks = n_bytes // block_size
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
return scales * qs
def dequantize_q2_k(data, n_bytes: int):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74
num_blocks = n_bytes // GGML_BLOCK_SIZES["Q2_K"]
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"] // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"])
dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
qs = data_u8[:, 16:80].reshape(num_blocks, 64)
tmp = np.stack(
[
qs[:, 00:16] >> 0,
qs[:, 16:32] >> 0,
qs[:, 00:16] >> 2,
qs[:, 16:32] >> 2,
qs[:, 00:16] >> 4,
qs[:, 16:32] >> 4,
qs[:, 00:16] >> 6,
qs[:, 16:32] >> 6,
qs[:, 32:48] >> 0,
qs[:, 48:64] >> 0,
qs[:, 32:48] >> 2,
qs[:, 48:64] >> 2,
qs[:, 32:48] >> 4,
qs[:, 48:64] >> 4,
qs[:, 32:48] >> 6,
qs[:, 48:64] >> 6,
],
axis=1,
)
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
def dequantize_q3_k(data, n_bytes: int):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95
num_blocks = n_bytes // GGML_BLOCK_SIZES["Q3_K"]
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"] // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"])
d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
bits = 4 ^ (bits << 2)
qs = data_u8[:, 32 : 32 + 64].astype(np.int16)
a, b, c = data_u8[:, 96 : 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
scales[:, 0] = (a & 15) | ((c & 3) << 4)
scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
return (
d
* (scales - 32)
* np.stack(
[
(((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
(((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
(((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
(((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
(((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
(((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
(((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
(((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
(((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
(((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
(((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
(((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
(((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
(((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
(((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
(((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]),
],
axis=1,
)
)
def dequantize_q5_k(data, n_bytes: int):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138
num_blocks = n_bytes // GGML_BLOCK_SIZES["Q5_K"]
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"] // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"])
d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qh = data_u8[:, 16 : 16 + 32].reshape(num_blocks, 32, 1)
qs = data_u8[:, 48 : 48 + 128].reshape(num_blocks, 4, 32)
bits = np.unpackbits(qh, axis=-1, bitorder="little")
qs_hi_4 = qs >> 4
qs_lo_4 = qs & 15
scales_lo_6 = scales[:, :8] & 63
scales_hi_6 = scales[:, :8] >> 6
scales_lo_4 = scales[:, 8:] & 15
scales_hi_4 = scales[:, 8:] >> 4
m1 = dmin * scales_lo_6[:, 4]
m2 = dmin * scales_lo_6[:, 5]
m3 = dmin * scales_lo_6[:, 6]
m4 = dmin * scales_lo_6[:, 7]
m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
d1 = d * scales_lo_6[:, 0]
d2 = d * scales_lo_6[:, 1]
d3 = d * scales_lo_6[:, 2]
d4 = d * scales_lo_6[:, 3]
d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
return np.concatenate(
[
d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
],
axis=1,
)
def load_dequant_gguf_tensor(shape, ggml_type, data, n_bytes):
if ggml_type == GGML_TYPES["F32"]:
values = data
elif ggml_type == GGML_TYPES["F16"]:
values = data
elif ggml_type == GGML_TYPES["Q8_0"]:
values = dequantize_q8_0(data, n_bytes)
elif ggml_type == GGML_TYPES["Q4_0"]:
values = dequantize_q4_0(data, n_bytes)
elif ggml_type == GGML_TYPES["Q4_K"]:
values = dequantize_q4_k(data, n_bytes)
elif ggml_type == GGML_TYPES["Q6_K"]:
values = dequantize_q6_k(data, n_bytes)
elif ggml_type == GGML_TYPES["Q2_K"]:
values = dequantize_q2_k(data, n_bytes)
elif ggml_type == GGML_TYPES["Q3_K"]:
values = dequantize_q3_k(data, n_bytes)
elif ggml_type == GGML_TYPES["Q5_K"]:
values = dequantize_q5_k(data, n_bytes)
else:
raise NotImplementedError(
f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose"
)
return values.reshape(shape[::-1])
class GGUFTokenizerSkeleton:
def __init__(self, dict_):
for k, v in dict_.items():

View File

@ -24,9 +24,9 @@ from .integrations import (
GGUF_TENSOR_MAPPING,
GGUF_TOKENIZER_MAPPING,
_gguf_parse_value,
load_dequant_gguf_tensor,
)
from .utils import is_torch_available
from .utils.import_utils import is_gguf_available
from .utils.logging import get_logger
@ -71,14 +71,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
Whether to read the tensors from the file and return them. Not doing so is faster
and only loads the metadata in memory.
"""
try:
from gguf import GGUFReader
except (ImportError, ModuleNotFoundError):
if is_gguf_available() and is_torch_available():
from gguf import GGUFReader, dequantize
else:
logger.error(
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF to be installed. Please see "
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
)
raise
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
reader = GGUFReader(gguf_checkpoint_path)
fields = reader.fields
@ -154,12 +154,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping]
)
shape = tensor.shape
name = tensor.name
weights = load_dequant_gguf_tensor(
shape=shape, ggml_type=tensor.tensor_type, data=tensor.data, n_bytes=int(tensor.n_bytes)
)
weights = dequantize(tensor.data, tensor.tensor_type)
if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
num_heads = parsed_parameters["config"]["num_attention_heads"]

View File

@ -53,6 +53,7 @@ from .integrations import (
from .integrations.deepspeed import is_deepspeed_available
from .utils import (
ACCELERATE_MIN_VERSION,
GGUF_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_aqlm_available,
@ -407,11 +408,13 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
)(test_case)
def require_gguf(test_case):
def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
"""
Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed.
"""
return unittest.skipUnless(is_gguf_available(), "test requires gguf")(test_case)
return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")(
test_case
)
def require_fsdp(test_case, min_version: str = "1.12.0"):

View File

@ -99,6 +99,7 @@ from .import_utils import (
ACCELERATE_MIN_VERSION,
ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES,
GGUF_MIN_VERSION,
TORCH_FX_REQUIRED_VERSION,
USE_JAX,
USE_TF,

View File

@ -89,6 +89,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
ACCELERATE_MIN_VERSION = "0.26.0"
FSDP_MIN_VERSION = "1.12.0"
GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
@ -156,7 +157,7 @@ _safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_is_seqio_available = _is_package_available("seqio")
_is_gguf_available = _is_package_available("gguf")
_is_gguf_available, _gguf_version = _is_package_available("gguf", return_version=True)
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
try:
@ -914,8 +915,8 @@ def is_seqio_available():
return _is_seqio_available
def is_gguf_available():
return _is_gguf_available
def is_gguf_available(min_version: str = GGUF_MIN_VERSION):
return _is_gguf_available and version.parse(_gguf_version) >= version.parse(min_version)
def is_protobuf_available():

View File

@ -30,18 +30,32 @@ if is_torch_available():
class GgufIntegrationTests(unittest.TestCase):
original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
q5_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_0.gguf"
q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
# k-quants
q2_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"
q3_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf"
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
# imatrix
iq1_m_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_M.gguf"
iq1_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_S.gguf"
iq2_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_S.gguf"
iq2_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XS.gguf"
iq2_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XXS.gguf"
iq3_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_S.gguf"
iq3_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_XXS.gguf"
iq4_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf"
iq4_nl_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_NL.gguf"
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
@ -87,6 +101,16 @@ class GgufIntegrationTests(unittest.TestCase):
EXPECTED_TEXT = "Hello, World!\n\n```\n<|user"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q5_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q5_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id).to(torch_device)
@ -151,6 +175,114 @@ class GgufIntegrationTests(unittest.TestCase):
EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq1_s(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I'm a friend of mine, I"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq1_m(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I am interested in purching a copy of"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq2_s(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello World!\n\n```\n<|user|"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq2_xs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello World!\n\n```\n<|user|"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq2_xxs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I'm a software engineer. I'"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq3_s(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq3_xxs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I am interested in your product. Can you"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq4_xs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq4_nl(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_f16(self):
tokenizer = AutoTokenizer.from_pretrained(self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id)
model = AutoModelForCausalLM.from_pretrained(