mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Rework a bit the LLaMA conversion script (#22236)
* Update LLaMA conversion script * Doc * Fix the weight size for the 13B checkpoint * Update src/transformers/models/llama/convert_llama_weights_to_hf.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> --------- Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
parent
43efd7cb13
commit
786092a35e
@ -35,10 +35,13 @@ python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
```python
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/output/path/tokenizer/")
|
||||
model = LlamaForCausalLM.from_pretrained("/output/path/llama-7b/")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
||||
model = LlamaForCausalLM.from_pretrained("/output/path")
|
||||
```
|
||||
|
||||
Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). For the 65B model, it's thus 130GB of RAM needed.
|
||||
|
||||
- The LLaMA tokenizer is based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string. To have the tokenizer output the prefix space, set `decode_with_prefix_space=True` in the `LlamaTokenizer` object or in the tokenizer configuration.
|
||||
|
||||
This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama).
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@ -19,22 +20,28 @@ import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
```
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
|
||||
Thereafter, models can be loaded via:
|
||||
|
||||
```
|
||||
tokenizer = transformers.LlamaTokenizer.from_pretrained("/output/path/tokenizer/")
|
||||
```py
|
||||
from transformers import LlamaForCausalLM, LlamaForTokenizer
|
||||
|
||||
model = transformers.LlamaForCausalLM.from_pretrained("/output/path/llama-7b/")
|
||||
```
|
||||
model = LlamaForCausalLM.from_pretrained("/output/path")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
||||
```
|
||||
|
||||
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
||||
"""
|
||||
|
||||
INTERMEDIATE_SIZE_MAP = {
|
||||
@ -66,8 +73,9 @@ def write_json(text, path):
|
||||
|
||||
|
||||
def write_model(model_path, input_base_path, model_size):
|
||||
assert model_size in NUM_SHARDS
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
tmp_model_path = os.path.join(model_path, "tmp")
|
||||
os.makedirs(tmp_model_path, exist_ok=True)
|
||||
|
||||
params = read_json(os.path.join(input_base_path, "params.json"))
|
||||
num_shards = NUM_SHARDS[model_size]
|
||||
@ -83,6 +91,7 @@ def write_model(model_path, input_base_path, model_size):
|
||||
def permute(w):
|
||||
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
||||
# Load weights
|
||||
if model_size == "7B":
|
||||
# Not shared
|
||||
@ -97,10 +106,7 @@ def write_model(model_path, input_base_path, model_size):
|
||||
param_count = 0
|
||||
index_dict = {"weight_map": {}}
|
||||
for layer_i in range(n_layers):
|
||||
filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
|
||||
layer_i + 1,
|
||||
n_layers + 1,
|
||||
)
|
||||
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
||||
if model_size == "7B":
|
||||
# Unsharded
|
||||
state_dict = {
|
||||
@ -120,11 +126,15 @@ def write_model(model_path, input_base_path, model_size):
|
||||
}
|
||||
else:
|
||||
# Sharded
|
||||
# Note that in the 13B checkpoint, not cloning the two following weights will result in the checkpoint
|
||||
# becoming 37GB instead of 26GB for some reason.
|
||||
state_dict = {
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][f"layers.{layer_i}.attention_norm.weight"],
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
|
||||
f"layers.{layer_i}.attention_norm.weight"
|
||||
].clone(),
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
|
||||
f"layers.{layer_i}.ffn_norm.weight"
|
||||
],
|
||||
].clone(),
|
||||
}
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
@ -169,12 +179,9 @@ def write_model(model_path, input_base_path, model_size):
|
||||
for k, v in state_dict.items():
|
||||
index_dict["weight_map"][k] = filename
|
||||
param_count += v.numel()
|
||||
torch.save(state_dict, os.path.join(model_path, filename))
|
||||
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
||||
|
||||
filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
|
||||
n_layers + 1,
|
||||
n_layers + 1,
|
||||
)
|
||||
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
||||
if model_size == "7B":
|
||||
# Unsharded
|
||||
state_dict = {
|
||||
@ -194,48 +201,38 @@ def write_model(model_path, input_base_path, model_size):
|
||||
for k, v in state_dict.items():
|
||||
index_dict["weight_map"][k] = filename
|
||||
param_count += v.numel()
|
||||
torch.save(state_dict, os.path.join(model_path, filename))
|
||||
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
||||
|
||||
# Write configs
|
||||
index_dict["metadata"] = {"total_size": param_count * 2}
|
||||
write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json"))
|
||||
config_out = {
|
||||
"architectures": ["LlamaForCausalLM"],
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": dim,
|
||||
"intermediate_size": compute_intermediate_size(dim),
|
||||
"initializer_range": 0.02,
|
||||
"max_sequence_length": 2048,
|
||||
"model_type": "llama",
|
||||
"num_attention_heads": params["n_heads"],
|
||||
"num_hidden_layers": params["n_layers"],
|
||||
"pad_token_id": 0,
|
||||
"rms_norm_eps": params["norm_eps"],
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.27.0.dev0",
|
||||
"use_cache": True,
|
||||
"vocab_size": 32000,
|
||||
}
|
||||
write_json(
|
||||
config_out,
|
||||
os.path.join(model_path, "config.json"),
|
||||
)
|
||||
generation_config = {
|
||||
"_from_model_config": True,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"pad_token_id": 0,
|
||||
"transformers_version": "4.27.0.dev0",
|
||||
}
|
||||
write_json(
|
||||
generation_config,
|
||||
os.path.join(model_path, "generation_config.json"),
|
||||
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
||||
|
||||
config = LlamaConfig(
|
||||
hidden_size=dim,
|
||||
intermediate_size=compute_intermediate_size(dim),
|
||||
num_attention_heads=params["n_heads"],
|
||||
num_hidden_layers=params["n_layers"],
|
||||
rms_norm_eps=params["norm_eps"],
|
||||
)
|
||||
config.save_pretrained(tmp_model_path)
|
||||
|
||||
# Make space so we can load the model properly now.
|
||||
del state_dict
|
||||
del loaded
|
||||
gc.collect()
|
||||
|
||||
print("Loading the checkpoint in a Llama model.")
|
||||
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||
# Avoid saving this as part of the config.
|
||||
del model.config._name_or_path
|
||||
|
||||
print("Saving in the Transformers format.")
|
||||
model.save_pretrained(model_path)
|
||||
shutil.rmtree(tmp_model_path)
|
||||
|
||||
|
||||
def write_tokenizer(tokenizer_path, input_tokenizer_path):
|
||||
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
|
||||
os.makedirs(tokenizer_path, exist_ok=True)
|
||||
write_json({}, os.path.join(tokenizer_path, "special_tokens_map.json"))
|
||||
write_json(
|
||||
@ -268,12 +265,12 @@ def main():
|
||||
args = parser.parse_args()
|
||||
if args.model_size != "tokenizer_only":
|
||||
write_model(
|
||||
model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
|
||||
model_path=args.output_dir,
|
||||
input_base_path=os.path.join(args.input_dir, args.model_size),
|
||||
model_size=args.model_size,
|
||||
)
|
||||
write_tokenizer(
|
||||
tokenizer_path=os.path.join(args.output_dir, "tokenizer"),
|
||||
tokenizer_path=args.output_dir,
|
||||
input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user