mirror of https://github.com/huggingface/transformers.git
Llama conversion script: adjustments for Llama Guard (#27910)
parent e366937587
commit e0b617d192
@@ -91,6 +91,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     params = read_json(os.path.join(input_base_path, "params.json"))
     num_shards = NUM_SHARDS[model_size]
+    params = params.get("model", params)
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
     n_heads_per_shard = n_heads // num_shards
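
The added line unwraps a nested configuration: Llama Guard checkpoints appear to place their hyperparameters under a top-level "model" key in params.json, while plain Llama checkpoints keep them at the top level, and params.get("model", params) handles both. A minimal sketch of the effect, using illustrative dictionaries rather than the actual checkpoint contents:

# Illustrative sketch of the added line; these dicts are assumptions,
# not the literal params.json shipped with either model.
llama_params = {"dim": 4096, "n_layers": 32, "n_heads": 32}
llama_guard_params = {"model": {"dim": 4096, "n_layers": 32, "n_heads": 32}}

for params in (llama_params, llama_guard_params):
    # Unwrap the nested "model" block when present; otherwise keep the dict as-is.
    params = params.get("model", params)
    print(params["n_layers"])  # prints 32 for both layouts
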
@@ -109,7 +110,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     tokenizer.save_pretrained(model_path)
     vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000

-    if "n_kv_heads" in params:
+    if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
         key_value_dim = dim // num_key_value_heads
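
The new condition only differs from the old one when a checkpoint lists n_kv_heads explicitly set to null: "n_kv_heads" in params is still True for such a dict, so the old code would read None and break the GQA head arithmetic below, whereas the new check treats an explicit None like a missing key and falls back to the plain multi-head-attention path. Presumably the Llama Guard params.json does this, though that is an assumption here. A small sketch of the difference with a made-up params dict:

# Made-up params dict to show why the two checks differ; the actual
# Llama Guard params.json contents are an assumption.
params = {"n_heads": 32, "n_kv_heads": None}

print("n_kv_heads" in params)                      # True  -> old check takes the GQA branch, then reads None
print(params.get("n_kv_heads", None) is not None)  # False -> new check falls back to multi-head attention
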