[Llama] Conversion: fix and simplify the script! (#31591)

* fix and simplify the script!

* add co-author

---------

Co-authored-by: crackalamoo <crackalamoo@users.noreply.github.com>
Commit 11138ca013 (parent c9f191a0b7), authored by Arthur on 2024-06-27 12:35:19 +02:00, committed by GitHub.

@@ -105,21 +105,18 @@ def write_json(text, path):
def write_model(
model_path,
input_base_path,
- model_size,
+ model_size=None,
safe_serialization=True,
llama_version=1,
vocab_size=None,
+ num_shards=None,
):
- # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
- if not os.path.isfile(os.path.join(input_base_path, "params.json")):
- input_base_path = os.path.join(input_base_path, model_size)
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
- num_shards = NUM_SHARDS[model_size]
+ num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
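
Note: the point of this first hunk is that `model_size` becomes optional and the shard count can be passed directly. A minimal sketch of the new resolution order, assuming the script's existing `NUM_SHARDS` lookup table (the table values and the helper name below are illustrative, not taken from the commit):

```python
# Illustrative stand-in for the script's NUM_SHARDS table.
NUM_SHARDS = {"7B": 1, "13B": 2, "30B": 4, "65B": 8, "70B": 8}

def resolve_num_shards(model_size=None, num_shards=None):
    """Hypothetical helper: an explicit num_shards wins, otherwise fall back to model_size."""
    if num_shards is not None:
        return num_shards
    return NUM_SHARDS[model_size]

assert resolve_num_shards(model_size="70B") == 8   # old behaviour, via --model_size
assert resolve_num_shards(num_shards=2) == 2       # new behaviour, via --num_shards
```
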
@@ -142,12 +139,13 @@ def write_model(
vocab_size = vocab_size if vocab_size is not None else 32000
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
- num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
- key_value_dim = dim // num_key_value_heads
+ num_key_value_heads_per_shard = num_key_value_heads // num_shards
+ key_value_dim = dims_per_head * num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
- num_local_key_value_heads = n_heads_per_shard
- key_value_dim = dim
+ num_key_value_heads_per_shard = n_heads_per_shard
+ key_value_dim = dims_per_head * num_key_value_heads
+ print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim)
# permute for sliced rotary
def permute(w, n_heads, dim1=dim, dim2=dim):
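
Note: this hunk renames `num_local_key_value_heads` to `num_key_value_heads_per_shard` and corrects `key_value_dim`: for GQA/MQA checkpoints the key/value projection spans `dims_per_head * num_key_value_heads` output features, not `dim // num_key_value_heads`. A quick worked example with Llama-3-8B-style sizes and a hypothetical shard count (values are illustrative, not read from a checkpoint):

```python
# GQA arithmetic with illustrative sizes (2 hypothetical shards).
dim, n_heads, n_kv_heads, num_shards = 4096, 32, 8, 2

dims_per_head = dim // n_heads                             # 128
num_key_value_heads_per_shard = n_kv_heads // num_shards   # 4
key_value_dim = dims_per_head * n_kv_heads                 # 1024: width of wk / wv
# The previous formula, dim // n_kv_heads, would give 512 here,
# which does not match the (1024, 4096) shape of the GQA k/v projections.
```
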
@@ -162,8 +160,9 @@ def write_model(
else:
# Sharded
loaded = [
- torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
- for i in range(num_shards)
+ torch.load(os.path.join(input_base_path, file), map_location="cpu")
+ for file in os.listdir(input_base_path)
+ if file.endswith(".pth")
]
param_count = 0
index_dict = {"weight_map": {}}
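
Note: instead of assuming `consolidated.{i:02d}.pth` names and exactly `num_shards` files, the loader now picks up every `.pth` file in the input directory. A rough stand-alone equivalent, written with an explicit sorted listing (the sorting and the placeholder path are my additions, only to make the sketch deterministic and runnable):

```python
import glob
import os

import torch

input_base_path = "path/to/llama_checkpoint"  # placeholder: folder with params.json and *.pth
shard_files = sorted(glob.glob(os.path.join(input_base_path, "*.pth")))
loaded = [torch.load(f, map_location="cpu") for f in shard_files]
print(f"Found {len(loaded)} shard file(s)")  # len(loaded) is what the rest of the script now uses
```
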
@@ -178,7 +177,7 @@ def write_model(
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wk.weight"],
n_heads=num_key_value_heads,
- dim1=dim // num_local_key_value_heads,
+ dim1=key_value_dim,
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
@@ -206,7 +205,7 @@ def write_model(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
- for i in range(num_shards)
+ for i in range(len(loaded))
],
dim=0,
).reshape(dim, dim),
@@ -216,9 +215,9 @@ def write_model(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
- num_local_key_value_heads, dims_per_head, dim
+ num_key_value_heads_per_shard, dims_per_head, dim
)
- for i in range(num_shards)
+ for i in range(len(loaded))
],
dim=0,
).reshape(key_value_dim, dim),
@@ -229,24 +228,24 @@ def write_model(
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
- num_local_key_value_heads, dims_per_head, dim
+ num_key_value_heads_per_shard, dims_per_head, dim
)
- for i in range(num_shards)
+ for i in range(len(loaded))
],
dim=0,
).reshape(key_value_dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0
)
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
@@ -268,9 +267,9 @@ def write_model(
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=concat_dim
[loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim
),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0),
}
for k, v in state_dict.items():
@@ -310,7 +309,7 @@ def write_model(
model.config.torch_dtype = torch.float16
print("Saving in the Transformers format.")
model.save_pretrained(model_path, safe_serialization=safe_serialization)
- shutil.rmtree(tmp_model_path)
+ shutil.rmtree(tmp_model_path, ignore_errors=True)
class Llama3Converter(TikTokenConverter):
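
Note: the cleanup of the temporary directory now passes `ignore_errors=True`, so the conversion no longer fails at the very end if the directory is missing or was already removed. A tiny self-contained illustration (the path is a placeholder that deliberately does not exist):

```python
import os
import shutil
import tempfile

path = os.path.join(tempfile.gettempdir(), "llama_tmp_that_never_existed")
shutil.rmtree(path, ignore_errors=True)  # no FileNotFoundError even though the path is absent
```
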
@@ -371,8 +370,8 @@ def main():
)
parser.add_argument(
"--model_size",
choices=["7B", "8B", "8Bf", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
default=None,
help="'f' Deprecated in favor of `num_shards`: models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
)
parser.add_argument(
"--output_dir",
@@ -389,7 +388,15 @@ def main():
type=int,
help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
)
+ parser.add_argument(
+ "--num_shards",
+ default=None,
+ type=int,
+ help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
+ )
args = parser.parse_args()
+ if args.model_size is None and args.num_shards is None:
+ raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
spm_path = os.path.join(args.input_dir, "tokenizer.model")
vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version))
if args.model_size != "tokenizer_only":
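
Note: `--num_shards` is now an explicit CLI argument, and at least one of `--model_size` / `--num_shards` must be provided. A standalone sketch of that contract (this parser is purely illustrative, not the script's actual parser object):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_size", default=None)
parser.add_argument("--num_shards", default=None, type=int)
args = parser.parse_args(["--num_shards", "1"])  # e.g. a single consolidated.00.pth

if args.model_size is None and args.num_shards is None:
    raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
print(args.num_shards)  # 1
```
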
@@ -400,6 +407,7 @@ def main():
safe_serialization=args.safe_serialization,
llama_version=args.llama_version,
vocab_size=vocab_size,
+ num_shards=args.num_shards,
)