Llama conversion script: adjustments for Llama Guard (#27910)

Pedro Cuenca 2023-12-08 16:02:50 +01:00 committed by GitHub
parent e366937587
commit e0b617d192

@@ -91,6 +91,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     params = read_json(os.path.join(input_base_path, "params.json"))
     num_shards = NUM_SHARDS[model_size]
+    params = params.get("model", params)
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
     n_heads_per_shard = n_heads // num_shards
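
For context (an assumption about the motivating checkpoint format, not stated in the diff itself): the Llama Guard release appears to ship a params.json whose hyperparameters are nested under a "model" key, whereas earlier Llama checkpoints keep them at the top level. A minimal sketch of how the added line covers both layouts:

    # Hypothetical params.json contents, for illustration only.
    flat = {"dim": 4096, "n_layers": 32, "n_heads": 32}
    nested = {"model": {"dim": 4096, "n_layers": 32, "n_heads": 32}}

    for params in (flat, nested):
        # Returns the nested dict when a "model" key exists,
        # otherwise falls back to the flat dict unchanged.
        params = params.get("model", params)
        print(params["n_layers"])  # prints 32 for both layouts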
@@ -109,7 +110,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     tokenizer.save_pretrained(model_path)
     vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
-    if "n_kv_heads" in params:
+    if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
         key_value_dim = dim // num_key_value_heads
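
A short sketch of what the tightened check changes, assuming (not confirmed by the diff) that a params.json can carry the key with an explicit null, which json.load maps to None:

    # Hypothetical params dict with an explicit null n_kv_heads.
    params = {"n_heads": 32, "n_kv_heads": None}

    # Old check: key is present, so the GQA/MQA branch is taken and later
    # arithmetic on None would fail.
    print("n_kv_heads" in params)                      # True

    # New check: an explicit null falls through to the plain MHA branch.
    print(params.get("n_kv_heads", None) is not None)  # False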