From e0b617d1928a83e83c3aa5506b7509b0268ff202 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 8 Dec 2023 16:02:50 +0100
Subject: [PATCH] Llama conversion script: adjustments for Llama Guard (#27910)

---
 src/transformers/models/llama/convert_llama_weights_to_hf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py
index 8b5930779c9..d2fc3a79aff 100644
--- a/src/transformers/models/llama/convert_llama_weights_to_hf.py
+++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -91,6 +91,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
 
     params = read_json(os.path.join(input_base_path, "params.json"))
     num_shards = NUM_SHARDS[model_size]
+    params = params.get("model", params)
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
     n_heads_per_shard = n_heads // num_shards
@@ -109,7 +110,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
         tokenizer.save_pretrained(model_path)
     vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
 
-    if "n_kv_heads" in params:
+    if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
         key_value_dim = dim // num_key_value_heads
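
A minimal standalone sketch (not part of the patch) of the behavior the two changes enable. The example params.json contents below are hypothetical, assuming a Llama Guard style checkpoint nests its hyperparameters under a top-level "model" key and that some checkpoints set "n_kv_heads" to an explicit null:

    import json

    # Hypothetical examples of the two params.json layouts the script may receive.
    llama_params = '{"n_layers": 32, "n_heads": 32, "n_kv_heads": null, "dim": 4096}'
    guard_params = '{"model": {"n_layers": 32, "n_heads": 32, "n_kv_heads": 8, "dim": 4096}}'

    for raw in (llama_params, guard_params):
        params = json.loads(raw)
        # Unwrap the nested "model" dict when present; otherwise keep params as-is.
        params = params.get("model", params)
        # Treat an explicit null the same as a missing key, which a plain
        # `"n_kv_heads" in params` membership test would not do.
        if params.get("n_kv_heads", None) is not None:
            print("GQA/MQA checkpoint with", params["n_kv_heads"], "KV heads")
        else:
            print("MHA checkpoint: KV heads equal n_heads =", params["n_heads"])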