mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix the incorrect permutation of gguf (#31788)
* Fix the incorrect permutation of gguf * rename num_kv_heads Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * add typing to num_kv_heads Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * rename variables * refactor permute function name * update the expected text of the llama3 q4 test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
6fbea6d237
commit
ac946aac25
@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -147,10 +149,11 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
|
||||
if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
|
||||
num_heads = parsed_parameters["config"]["num_attention_heads"]
|
||||
tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
|
||||
weights = weights.reshape(tmp_shape)
|
||||
weights = weights.transpose(0, 2, 1, 3)
|
||||
weights = weights.reshape(shape[::-1])
|
||||
num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
|
||||
if ".attn_q." in name:
|
||||
weights = reverse_permute_weights(weights, num_heads, num_heads)
|
||||
elif ".attn_k." in name:
|
||||
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
|
||||
|
||||
for tensor_name in tensor_key_mapping:
|
||||
if tensor_name in name:
|
||||
@ -163,3 +166,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
|
||||
|
||||
return parsed_parameters
|
||||
|
||||
|
||||
def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
|
||||
# Original permutation implementation
|
||||
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
|
||||
if num_kv_heads is not None and n_head != num_kv_heads:
|
||||
n_head = num_kv_heads
|
||||
|
||||
dim = weights.shape[0] // n_head // 2
|
||||
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
|
||||
return w.swapaxes(2, 1).reshape(weights.shape)
|
||||
|
@ -188,8 +188,7 @@ class GgufIntegrationTests(unittest.TestCase):
|
||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**text, max_new_tokens=10)
|
||||
|
||||
EXPECTED_TEXT = "Hello, I am new to this forum. I am"
|
||||
|
||||
EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
|
||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_tokenization_xnli(self):
|
||||
|
Loading…
Reference in New Issue
Block a user