mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
[Llama] Conversion: fix and simplify the script! (#31591)

* fix and simplify the script!

* add co-author

---------

Co-authored-by: crackalamoo <crackalamoo@users.noreply.github.com>
parent c9f191a0b7
commit 11138ca013
@@ -105,21 +105,18 @@ def write_json(text, path):
 def write_model(
     model_path,
     input_base_path,
-    model_size,
+    model_size=None,
     safe_serialization=True,
     llama_version=1,
     vocab_size=None,
+    num_shards=None,
 ):
-    # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
-    if not os.path.isfile(os.path.join(input_base_path, "params.json")):
-        input_base_path = os.path.join(input_base_path, model_size)
-
     os.makedirs(model_path, exist_ok=True)
     tmp_model_path = os.path.join(model_path, "tmp")
     os.makedirs(tmp_model_path, exist_ok=True)
 
     params = read_json(os.path.join(input_base_path, "params.json"))
-    num_shards = NUM_SHARDS[model_size]
+    num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards
     params = params.get("model", params)
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
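The signature change makes `model_size` optional and adds `num_shards`; an explicit `num_shards` now takes precedence, and the `NUM_SHARDS` lookup only runs as a fallback. A minimal sketch of the resolution logic in the Llama conversion script (`convert_llama_weights_to_hf.py`); the table below is an illustrative subset, not the script's full `NUM_SHARDS` dict:

    # Illustrative subset of the shard table; these values are assumptions.
    NUM_SHARDS = {"7B": 1, "13B": 2, "70B": 8}

    def resolve_num_shards(model_size=None, num_shards=None):
        # An explicit num_shards wins; the lookup only happens as a fallback,
        # so callers may now omit model_size entirely.
        return NUM_SHARDS[model_size] if num_shards is None else num_shards

    assert resolve_num_shards(model_size="70B") == 8
    assert resolve_num_shards(num_shards=4) == 4  # no model_size needed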
@@ -142,12 +139,13 @@ def write_model(
     vocab_size = vocab_size if vocab_size is not None else 32000
     if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
-        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
-        key_value_dim = dim // num_key_value_heads
+        num_key_value_heads_per_shard = num_key_value_heads // num_shards
+        key_value_dim = dims_per_head * num_key_value_heads
     else:  # compatibility with other checkpoints
         num_key_value_heads = n_heads
-        num_local_key_value_heads = n_heads_per_shard
-        key_value_dim = dim
+        num_key_value_heads_per_shard = n_heads_per_shard
+        key_value_dim = dims_per_head * num_key_value_heads
+    print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim)
 
     # permute for sliced rotary
     def permute(w, n_heads, dim1=dim, dim2=dim):
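This hunk is the substantive fix: for GQA/MQA checkpoints, the per-shard KV head count is now derived from `num_key_value_heads` (the KV heads are what get split across shards), and `key_value_dim` is the actual output width of `wk`/`wv`. A worked check with illustrative values (32 heads, 8 KV heads, dim 4096, 2 shards; not tied to any particular checkpoint):

    n_heads, num_key_value_heads, dim, num_shards = 32, 8, 4096, 2
    dims_per_head = dim // n_heads                 # 128
    n_heads_per_shard = n_heads // num_shards      # 16

    # Old formulas (buggy):
    #   n_heads_per_shard // num_key_value_heads  -> 2   (each shard really holds 4 KV heads)
    #   dim // num_key_value_heads                -> 512 (not the width of wk/wv)
    # New formulas from this hunk:
    num_key_value_heads_per_shard = num_key_value_heads // num_shards  # 4
    key_value_dim = dims_per_head * num_key_value_heads                # 1024 = 128 * 8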
@@ -162,8 +160,9 @@ def write_model(
     else:
         # Sharded
         loaded = [
-            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
-            for i in range(num_shards)
+            torch.load(os.path.join(input_base_path, file), map_location="cpu")
+            for file in os.listdir(input_base_path)
+            if file.endswith(".pth")
         ]
     param_count = 0
     index_dict = {"weight_map": {}}
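Shard files are now discovered by scanning for `.pth` files rather than assuming `consolidated.{i:02d}.pth` names, which is what allows `num_shards` to disagree with the file count. One caveat: `os.listdir` returns entries in arbitrary order, so shard order is whatever the filesystem gives back. A hedged sketch of the same pattern with an explicit `sorted()` added as a safeguard (the sort is my addition, not part of the diff):

    import os
    import torch

    def load_shards(input_base_path):
        # Collect every .pth checkpoint in the directory, in a stable order.
        shard_files = sorted(f for f in os.listdir(input_base_path) if f.endswith(".pth"))
        return [
            torch.load(os.path.join(input_base_path, f), map_location="cpu")
            for f in shard_files
        ]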
@@ -178,7 +177,7 @@ def write_model(
                 f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                     loaded[f"layers.{layer_i}.attention.wk.weight"],
                     n_heads=num_key_value_heads,
-                    dim1=dim // num_local_key_value_heads,
+                    dim1=key_value_dim,
                 ),
                 f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
                 f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
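Two things changed here at once: the earlier hunk renamed `num_local_key_value_heads` away, so this line had to be updated, and `dim1=key_value_dim` now names the intended quantity directly. The old expression was equivalent whenever it divided evenly: in the unsharded branch, dim // (n_heads // n_kv_heads) == dims_per_head * n_kv_heads, i.e. 4096 // 4 == 128 * 8 == 1024 with the illustrative numbers above.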
@@ -206,7 +205,7 @@ def write_model(
                 torch.cat(
                     [
                         loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
-                        for i in range(num_shards)
+                        for i in range(len(loaded))
                     ],
                     dim=0,
                 ).reshape(dim, dim),
@@ -216,9 +215,9 @@ def write_model(
                 torch.cat(
                     [
                         loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
-                            num_local_key_value_heads, dims_per_head, dim
+                            num_key_value_heads_per_shard, dims_per_head, dim
                         )
-                        for i in range(num_shards)
+                        for i in range(len(loaded))
                     ],
                     dim=0,
                 ).reshape(key_value_dim, dim),
@@ -229,24 +228,24 @@ def write_model(
             state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
                 [
                     loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
-                        num_local_key_value_heads, dims_per_head, dim
+                        num_key_value_heads_per_shard, dims_per_head, dim
                     )
-                    for i in range(num_shards)
+                    for i in range(len(loaded))
                 ],
                 dim=0,
             ).reshape(key_value_dim, dim)
 
             state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
-                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
+                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1
             )
             state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
-                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
+                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0
             )
             state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
-                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
+                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1
             )
             state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
-                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0
             )
 
             state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
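Every concatenation loop in this and the surrounding hunks now ranges over `len(loaded)` (the shards actually found on disk) instead of the nominal `num_shards`. The reassembly scheme itself is unchanged: row-parallel weights (`wq`, `wk`, `wv`, `w1`, `w3`) are concatenated on dim=0 after a per-head view, column-parallel ones (`wo`, `w2`) on dim=1. A self-contained sketch of the `q_proj` case with toy sizes (all values here are illustrative):

    import torch

    num_shards, n_heads, dims_per_head = 2, 4, 2
    dim = n_heads * dims_per_head                # 8
    n_heads_per_shard = n_heads // num_shards    # each shard holds 2 heads

    # Each shard stores its slice of wq with shape (heads_on_shard * head_dim, dim).
    loaded = [torch.randn(n_heads_per_shard * dims_per_head, dim) for _ in range(num_shards)]

    # View each slice as (heads, head_dim, dim), stack the heads, flatten back.
    wq = torch.cat(
        [loaded[i].view(n_heads_per_shard, dims_per_head, dim) for i in range(len(loaded))],
        dim=0,
    ).reshape(dim, dim)
    assert wq.shape == (dim, dim)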
@@ -268,9 +267,9 @@ def write_model(
         state_dict = {
             "model.norm.weight": loaded[0]["norm.weight"],
             "model.embed_tokens.weight": torch.cat(
-                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=concat_dim
+                [loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim
             ),
-            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
+            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0),
         }
 
         for k, v in state_dict.items():
@@ -310,7 +309,7 @@ def write_model(
     model.config.torch_dtype = torch.float16
     print("Saving in the Transformers format.")
     model.save_pretrained(model_path, safe_serialization=safe_serialization)
-    shutil.rmtree(tmp_model_path)
+    shutil.rmtree(tmp_model_path, ignore_errors=True)
 
 
 class Llama3Converter(TikTokenConverter):
@@ -371,8 +370,8 @@ def main():
     )
     parser.add_argument(
         "--model_size",
-        choices=["7B", "8B", "8Bf", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
-        help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
+        default=None,
+        help="Deprecated in favor of `num_shards`: 'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, check out the original repo: https://huggingface.co/meta-llama",
     )
     parser.add_argument(
         "--output_dir",
@@ -389,7 +388,15 @@ def main():
         type=int,
         help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
     )
+    parser.add_argument(
+        "--num_shards",
+        default=None,
+        type=int,
+        help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
+    )
     args = parser.parse_args()
+    if args.model_size is None and args.num_shards is None:
+        raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version))
     if args.model_size != "tokenizer_only":
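Taken together with the validation above, `--model_size` can now be dropped as long as `--num_shards` is given. A hypothetical invocation (the paths and shard count are placeholders, and the script path is abbreviated):

    python convert_llama_weights_to_hf.py \
        --input_dir /path/to/llama-2-70b \
        --output_dir /path/to/llama-2-70b-hf \
        --llama_version 2 \
        --num_shards 8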
@@ -400,6 +407,7 @@ def main():
             safe_serialization=args.safe_serialization,
             llama_version=args.llama_version,
             vocab_size=vocab_size,
+            num_shards=args.num_shards,
         )
 