Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Support Llama 3.2 conversion (text models) (#33778)

* Support Llama 3.2 conversion (text models)
* Fix rope factor
* Update chat template: initialize from a well-known template. The guidance is that the changes should be applied to 3.1 models as well.
* Remove import
* Support Llama Guard 3 conversion
* Tokenizer details
* Fix eos added token in base models
* Fix generation config for base models
* Specify revision for known tokenizers
* Style
* Reuse chat templates for older models
* Improve error when converting tokenizer < Llama 3

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
This commit is contained in: parent c1c7e89620 · commit 7a06d07e14
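For orientation before the diff: the snippet below sketches how the updated converter would typically be exercised, based only on the sample usage embedded in the script's docstring and on standard `transformers` loading calls. The paths (`/path/to/downloaded/llama/weights`, `/output/path`) are placeholders, not values from this commit.

```
# Conversion (run from the repository root; paths are placeholders):
#   python src/transformers/models/llama/convert_llama_weights_to_hf.py \
#       --input_dir /path/to/downloaded/llama/weights --model_size 1B --llama_version 3.2 --output_dir /output/path
# Afterwards the converted checkpoint and fast tokenizer load through the usual API:
from transformers import AutoTokenizer, LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = AutoTokenizer.from_pretrained("/output/path")
```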
@@ -15,11 +15,12 @@ import argparse
import gc
import json
import os
import shutil
import tempfile
import warnings
from typing import List

import torch
from tokenizers import AddedToken, processors

from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import TikTokenConverter
@@ -39,7 +40,7 @@ Sample usage:

```
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
    --input_dir /path/to/downloaded/llama/weights --model_size 1B --llama_version 3.2 --output_dir /output/path
```

Thereafter, models can be loaded via:
@@ -75,6 +76,8 @@ tokenizer._tokenizers.post_processor = processors.Sequence(
"""

NUM_SHARDS = {
    "1B": 1,
    "3B": 1,
    "7B": 1,
    "8B": 1,
    "8Bf": 1,
@@ -90,284 +93,17 @@ NUM_SHARDS = {
    "405B-MP16": 16,
}

CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048}


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(
    model_path,
    input_base_path,
    model_size=None,
    safe_serialization=True,
    llama_version="1",
    vocab_size=None,
    num_shards=None,
    instruct=False,
):
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards
    params = params.get("model", params)
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = params.get("rope_theta", 10000.0)
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    if base > 10000.0 and float(llama_version) < 3:
        max_position_embeddings = 16384
    else:
        max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]

    if params.get("n_kv_heads", None) is not None:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_key_value_heads_per_shard = num_key_value_heads // num_shards
        key_value_dim = dims_per_head * num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_key_value_heads_per_shard = n_heads_per_shard
        key_value_dim = dim

    # permute for sliced rotary
    def permute(w, n_heads, dim1=dim, dim2=dim):
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    # Load weights
    if num_shards == 1:
        # Not sharded
        # (The sharded implementation would also work, but this is simpler.)
        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
    else:
        # Sharded
        checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
        print("Loading in order:", checkpoint_list)
        loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
    param_count = 0
    index_dict = {"weight_map": {}}
    for layer_i in range(n_layers):
        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        if num_shards == 1:
            # Unsharded
            state_dict = {
                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
                ),
                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wk.weight"],
                    n_heads=num_key_value_heads,
                    dim1=key_value_dim,
                ),
                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
            }
        else:
            # Sharded
            # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
            # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
            # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.

            state_dict = {
                f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
                    f"layers.{layer_i}.attention_norm.weight"
                ].clone(),
                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
                    f"layers.{layer_i}.ffn_norm.weight"
                ].clone(),
            }
            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
                        for i in range(len(loaded))
                    ],
                    dim=0,
                ).reshape(dim, dim),
                n_heads=n_heads,
            )
            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
                            num_key_value_heads_per_shard, dims_per_head, dim
                        )
                        for i in range(len(loaded))
                    ],
                    dim=0,
                ).reshape(key_value_dim, dim),
                num_key_value_heads,
                key_value_dim,
                dim,
            )
            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
                [
                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
                        num_key_value_heads_per_shard, dims_per_head, dim
                    )
                    for i in range(len(loaded))
                ],
                dim=0,
            ).reshape(key_value_dim, dim)

            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1
            )
            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0
            )
            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1
            )
            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0
            )

        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))

    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    if num_shards == 1:
        # Unsharded
        state_dict = {
            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
            "model.norm.weight": loaded["norm.weight"],
            "lm_head.weight": loaded["output.weight"],
        }
    else:
        concat_dim = 0 if llama_version in ["3", "3.1"] else 1
        state_dict = {
            "model.norm.weight": loaded[0]["norm.weight"],
            "model.embed_tokens.weight": torch.cat(
                [loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim
            ),
            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0),
        }

    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # Write configs
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
    ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
    multiple_of = params["multiple_of"] if "multiple_of" in params else 256

    if llama_version in ["3", "3.1"]:
        bos_token_id = 128000

        if instruct:
            eos_token_id = [128001, 128008, 128009]
        else:
            eos_token_id = 128001
    else:
        bos_token_id = 1
        eos_token_id = 2

    config = LlamaConfig(
        hidden_size=dim,
        intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=num_key_value_heads,
        vocab_size=vocab_size,
        rope_theta=base,
        max_position_embeddings=max_position_embeddings,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
    )
    config.save_pretrained(tmp_model_path)

    if instruct:
        generation_config = GenerationConfig(
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
        )
        generation_config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    gc.collect()

    print("Loading the checkpoint in a Llama model.")
    model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
    # Avoid saving this as part of the config.
    del model.config._name_or_path
    model.config.torch_dtype = torch.float16
    print("Saving in the Transformers format.")
    model.save_pretrained(model_path, safe_serialization=safe_serialization)
    shutil.rmtree(tmp_model_path, ignore_errors=True)


class Llama3Converter(TikTokenConverter):
    def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs):
        super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs)
        tokenizer = self.converted()
        chat_template = (
            "{% set loop_messages = messages %}"
            "{% for message in loop_messages %}"
            "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
            "{% if loop.index0 == 0 %}"
            "{% set content = bos_token + content %}"
            "{% endif %}"
            "{{ content }}"
            "{% endfor %}"
            "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
        )

        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            bos_token="<|begin_of_text|>",
            eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
            chat_template=chat_template if instruct else None,
            model_input_names=["input_ids", "attention_mask"],
            model_max_length=model_max_length,
        )


def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False):
    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
    if llama_version in ["3", "3.1"]:
        tokenizer = Llama3Converter(
            input_tokenizer_path, special_tokens, instruct, model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version]
        ).tokenizer
    else:
        tokenizer = tokenizer_class(input_tokenizer_path)
    print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
    tokenizer.save_pretrained(tokenizer_path)
    return tokenizer
CONTEXT_LENGTH_FOR_VERSION = {"Guard-3": 131072, "3.2": 131072, "3.1": 131072, "3": 8192, "2": 4096, "1": 2048}

BOS_ADDED_TOKEN = AddedToken(
    "<|begin_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
)
EOS_ADDED_TOKEN = AddedToken(
    "<|end_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
)
EOT_ADDED_TOKEN = AddedToken(
    "<|eot_id|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
)

DEFAULT_LLAMA_SPECIAL_TOKENS = {
    "3": [
@@ -397,14 +133,392 @@ DEFAULT_LLAMA_SPECIAL_TOKENS = {
        "<|python_tag|>",
    ]
    + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
    "3.2": [
        "<|begin_of_text|>",
        "<|end_of_text|>",
        "<|reserved_special_token_0|>",
        "<|reserved_special_token_1|>",
        "<|finetune_right_pad_id|>",
        "<|reserved_special_token_2|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eom_id|>",  # end of message
        "<|eot_id|>",  # end of turn
        "<|python_tag|>",
    ]
    + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
    "Guard-3": [
        "<|begin_of_text|>",
        "<|end_of_text|>",
        "<|reserved_special_token_0|>",
        "<|reserved_special_token_1|>",
        "<|finetune_right_pad_id|>",
        "<|reserved_special_token_2|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eom_id|>",  # end of message
        "<|eot_id|>",  # end of turn
        "<|python_tag|>",
    ]
    + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
}


def is_llama_3(version):
    return version in ["3", "3.1", "3.2", "Guard-3"]


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(
    model_path,
    input_base_path,
    model_size=None,
    safe_serialization=True,
    llama_version="1",
    vocab_size=None,
    num_shards=None,
    instruct=False,
    push_to_hub=False,
):
    print("Converting the model.")
    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards
    params = params.get("model", params)
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = params.get("rope_theta", 10000.0)
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    if base > 10000.0 and not is_llama_3(llama_version):
        max_position_embeddings = 16384
    else:
        max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]

    if params.get("n_kv_heads", None) is not None:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_key_value_heads_per_shard = num_key_value_heads // num_shards
        key_value_dim = dims_per_head * num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_key_value_heads_per_shard = n_heads_per_shard
        key_value_dim = dim

    # permute for sliced rotary
    def permute(w, n_heads, dim1=dim, dim2=dim):
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    with tempfile.TemporaryDirectory() as tmp_model_path:
        print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
        # Load weights
        if num_shards == 1:
            # Not sharded
            # (The sharded implementation would also work, but this is simpler.)
            loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
        else:
            # Sharded
            checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
            print("Loading in order:", checkpoint_list)
            loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
        param_count = 0
        index_dict = {"weight_map": {}}
        for layer_i in range(n_layers):
            filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
            if num_shards == 1:
                # Unsharded
                state_dict = {
                    f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                        loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
                    ),
                    f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                        loaded[f"layers.{layer_i}.attention.wk.weight"],
                        n_heads=num_key_value_heads,
                        dim1=key_value_dim,
                    ),
                    f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
                    f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
                    f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
                    f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
                    f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
                    f"model.layers.{layer_i}.input_layernorm.weight": loaded[
                        f"layers.{layer_i}.attention_norm.weight"
                    ],
                    f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
                        f"layers.{layer_i}.ffn_norm.weight"
                    ],
                }
            else:
                # Sharded
                # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
                # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
                # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.

                state_dict = {
                    f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
                        f"layers.{layer_i}.attention_norm.weight"
                    ].clone(),
                    f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
                        f"layers.{layer_i}.ffn_norm.weight"
                    ].clone(),
                }
                state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
                    torch.cat(
                        [
                            loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(
                                n_heads_per_shard, dims_per_head, dim
                            )
                            for i in range(len(loaded))
                        ],
                        dim=0,
                    ).reshape(dim, dim),
                    n_heads=n_heads,
                )
                state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
                    torch.cat(
                        [
                            loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
                                num_key_value_heads_per_shard, dims_per_head, dim
                            )
                            for i in range(len(loaded))
                        ],
                        dim=0,
                    ).reshape(key_value_dim, dim),
                    num_key_value_heads,
                    key_value_dim,
                    dim,
                )
                state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
                            num_key_value_heads_per_shard, dims_per_head, dim
                        )
                        for i in range(len(loaded))
                    ],
                    dim=0,
                ).reshape(key_value_dim, dim)

                state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
                    [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1
                )
                state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
                    [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0
                )
                state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
                    [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1
                )
                state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
                    [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0
                )

            state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
            for k, v in state_dict.items():
                index_dict["weight_map"][k] = filename
                param_count += v.numel()
            torch.save(state_dict, os.path.join(tmp_model_path, filename))

        filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
        if num_shards == 1:
            # Unsharded
            state_dict = {
                "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
                "model.norm.weight": loaded["norm.weight"],
                "lm_head.weight": loaded["output.weight"],
            }
        else:
            concat_dim = 0 if is_llama_3(llama_version) else 1
            state_dict = {
                "model.norm.weight": loaded[0]["norm.weight"],
                "model.embed_tokens.weight": torch.cat(
                    [loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim
                ),
                "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0),
            }

        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))

        # Write configs
        index_dict["metadata"] = {"total_size": param_count * 2}
        write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
        ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
        multiple_of = params["multiple_of"] if "multiple_of" in params else 256

        if is_llama_3(llama_version):
            bos_token_id = 128000

            if instruct:
                eos_token_id = [128001, 128008, 128009]
            else:
                eos_token_id = 128001
        else:
            bos_token_id = 1
            eos_token_id = 2

        if llama_version in ["3.1", "3.2", "Guard-3"]:
            rope_scaling = {
                "factor": 32.0 if llama_version == "3.2" else 8.0,
                "low_freq_factor": 1.0,
                "high_freq_factor": 4.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            }
        else:
            rope_scaling = None

        config = LlamaConfig(
            hidden_size=dim,
            intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
            num_attention_heads=params["n_heads"],
            num_hidden_layers=params["n_layers"],
            rms_norm_eps=params["norm_eps"],
            num_key_value_heads=num_key_value_heads,
            vocab_size=vocab_size,
            rope_theta=base,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=True if llama_version in ["3.2"] else False,
        )

        config.save_pretrained(tmp_model_path)

        generation_config = GenerationConfig(
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
        )
        generation_config.save_pretrained(tmp_model_path)

        # Make space so we can load the model properly now.
        del state_dict
        del loaded
        gc.collect()

        print("Loading the checkpoint in a Llama model.")
        model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)

        # Avoid saving this as part of the config.
        del model.config._name_or_path
        model.config.torch_dtype = torch.float16

        print("Saving in the Transformers format.")
        if push_to_hub:
            print("Pushing to the hub.")
            model.push_to_hub(model_path, safe_serialization=safe_serialization, private=True, use_temp_dir=True)
        else:
            print("Saving to disk.")
            model.save_pretrained(model_path, safe_serialization=safe_serialization)


class Llama3Converter(TikTokenConverter):
    def __init__(self, vocab_file, special_tokens=None, instruct=False, llama_version="3.2", **kwargs):
        super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs)
        tokenizer = self.converted()

        # References for chat templates in instruct models
        templates_for_version = {
            "2": ("meta-llama/Llama-2-7b-chat-hf", "f5db02db724555f92da89c216ac04704f23d4590"),
            "3": ("meta-llama/Meta-Llama-3-8B-Instruct", "5f0b02c75b57c5855da9ae460ce51323ea669d8a"),
            "3.1": ("meta-llama/Llama-3.1-8B-Instruct", "0e9e39f249a16976918f6564b8830bc894c89659"),
            "3.2": ("meta-llama/Llama-3.2-1B-Instruct", "e9f8effbab1cbdc515c11ee6e098e3d5a9f51e14"),
            "Guard-3": ("meta-llama/Llama-Guard-3-1B", "acf7aafa60f0410f8f42b1fa35e077d705892029"),
        }

        # Add chat_template only if instruct is True.
        # Prevents a null chat_template, which triggers
        # a parsing warning in the Hub.
        additional_kwargs = {}
        if instruct or llama_version in ["Guard-3"]:
            model_id, revision = templates_for_version.get(llama_version, (None, None))
            if model_id is not None:
                from transformers import AutoTokenizer

                t = AutoTokenizer.from_pretrained(model_id, revision=revision)
                additional_kwargs["chat_template"] = t.chat_template

        self.converted_tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            bos_token="<|begin_of_text|>",
            eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
            model_input_names=["input_ids", "attention_mask"],
            model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version],
            clean_up_tokenization_spaces=True,
            **additional_kwargs,
        )
        self.update_post_processor(self.converted_tokenizer)
        # finer special_tokens_map.json
        self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN
        self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN if instruct else EOS_ADDED_TOKEN

    # We can't do this while building the tokenizer because we have no easy access to the bos token id
    def update_post_processor(self, tokenizer):
        tokenizer._tokenizer.post_processor = processors.Sequence(
            [
                processors.ByteLevel(trim_offsets=False),
                processors.TemplateProcessing(
                    single="<|begin_of_text|> $A",
                    pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1",
                    special_tokens=[
                        ("<|begin_of_text|>", tokenizer.convert_tokens_to_ids("<|begin_of_text|>")),
                    ],
                ),
            ]
        )


def write_tokenizer(
    tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False, push_to_hub=False
):
    print("Converting the tokenizer.")
    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
    if is_llama_3(llama_version):
        tokenizer = Llama3Converter(
            input_tokenizer_path,
            special_tokens,
            instruct,
            llama_version,
        ).converted_tokenizer
    else:
        try:
            tokenizer = tokenizer_class(input_tokenizer_path)
        except Exception:
            raise ValueError(
                "Failed to instantiate tokenizer. Please, make sure you have sentencepiece and protobuf installed."
            )

    if push_to_hub:
        print(f"Pushing a {tokenizer_class.__name__} to the Hub repo - {tokenizer_path}.")
        tokenizer.push_to_hub(tokenizer_path, private=True, use_temp_dir=True)
    else:
        print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
        tokenizer.save_pretrained(tokenizer_path)
    return tokenizer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
        help="Location of Llama weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--model_size",
@@ -416,12 +530,18 @@ def main():
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
        "--push_to_hub",
        help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
    )
    # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
    parser.add_argument(
        "--llama_version",
        choices=["1", "2", "3", "3.1"],
        choices=["1", "2", "3", "3.1", "3.2", "Guard-3"],
        default="1",
        type=str,
        help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
@@ -440,9 +560,9 @@ def main():
    )
    parser.add_argument(
        "--instruct",
        action="store_true",
        default=False,
        type=bool,
        help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.",
        help="Whether the model is an instruct model or not. Will affect special tokens and chat template.",
    )
    args = parser.parse_args()
    if args.model_size is None and args.num_shards is None:
@@ -459,8 +579,10 @@ def main():
            llama_version=args.llama_version,
            special_tokens=args.special_tokens,
            instruct=args.instruct,
            push_to_hub=args.push_to_hub,
        )
    )

    if args.model_size != "tokenizer_only":
        write_model(
            model_path=args.output_dir,
@@ -471,6 +593,7 @@ def main():
            vocab_size=vocab_size,
            num_shards=args.num_shards,
            instruct=args.instruct,
            push_to_hub=args.push_to_hub,
        )
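Reading the new `write_model` above, a `--llama_version 3.2` conversion should emit a config with llama3-style RoPE scaling and tied word embeddings. The dictionary below merely restates the values hard-coded in the diff as a sanity-check sketch; it is not copied from an actual converted checkpoint, and the eos list applies to the `--instruct` case.

```
# Expected config fields for a Llama 3.2 text-model conversion, per write_model() above (illustrative only)
expected_llama_3_2_config = {
    "rope_scaling": {
        "factor": 32.0,  # 8.0 for 3.1 and Guard-3
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "max_position_embeddings": 131072,  # CONTEXT_LENGTH_FOR_VERSION["3.2"]
    "tie_word_embeddings": True,  # only set for llama_version == "3.2"
    "bos_token_id": 128000,
    "eos_token_id": [128001, 128008, 128009],  # instruct models; base models use 128001
}
```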