Rework the LLaMA conversion script a bit (#22236)

* Update LLaMA conversion script

* Doc

* Fix the weight size for the 13B checkpoint

* Update src/transformers/models/llama/convert_llama_weights_to_hf.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

---------

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
Sylvain Gugger 2023-03-20 11:30:36 -04:00 committed by GitHub
parent 43efd7cb13
commit 786092a35e
2 changed files with 59 additions and 59 deletions

docs/source/en/model_doc/llama.mdx

@@ -35,10 +35,13 @@ python src/transformers/models/llama/convert_llama_weights_to_hf.py \
```python
from transformers import LlamaForCausalLM, LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained("/output/path/tokenizer/")
model = LlamaForCausalLM.from_pretrained("/output/path/llama-7b/")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = LlamaForCausalLM.from_pretrained("/output/path")
```
+Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even though the biggest versions
+come in several checkpoints, each checkpoint contains a part of every weight of the model, so they all have to be loaded in RAM). For the 65B model, that means 130GB of RAM (see the sizing sketch below).
- The LLaMA tokenizer is based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string. To have the tokenizer output the prefix space, set `decode_with_prefix_space=True` in the `LlamaTokenizer` object or in the tokenizer configuration.
This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama).
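The 130GB figure quoted in the note above follows from storing every parameter in float16, i.e. two bytes per parameter. A rough sizing sketch, where the parameter counts are the approximate published LLaMA sizes and not values read from the checkpoints:

```python
# Approximate published parameter counts for the four LLaMA checkpoints
# (an assumption for illustration, not read from the downloaded weights).
PARAM_COUNTS = {"7B": 6.7e9, "13B": 13.0e9, "30B": 32.5e9, "65B": 65.2e9}

for size, n_params in PARAM_COUNTS.items():
    # float16 stores each parameter in 2 bytes.
    print(f"{size}: ~{n_params * 2 / 1e9:.0f} GB of CPU RAM")
# 65B: ~130 GB, matching the figure above; the smaller models need proportionally less.
```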

src/transformers/models/llama/convert_llama_weights_to_hf.py

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
+import gc
import json
import math
import os
@@ -19,22 +20,28 @@ import shutil
import torch
from transformers import LlamaConfig, LlamaForCausalLM
"""
Sample usage:
-    ```
-    python src/transformers/models/llama/convert_llama_weights_to_hf.py \
-        --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
-    ```
+```
+python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
+```
Thereafter, models can be loaded via:
-    ```
-    tokenizer = transformers.LlamaTokenizer.from_pretrained("/output/path/tokenizer/")
+```py
+from transformers import LlamaForCausalLM, LlamaTokenizer
-    model = transformers.LlamaForCausalLM.from_pretrained("/output/path/llama-7b/")
-    ```
+model = LlamaForCausalLM.from_pretrained("/output/path")
+tokenizer = LlamaTokenizer.from_pretrained("/output/path")
+```
+Important note: you need to be able to host the whole model in RAM to execute this script (even though the biggest versions
+come in several checkpoints, each checkpoint contains a part of every weight of the model, so they all have to be loaded in RAM).
"""
INTERMEDIATE_SIZE_MAP = {
@@ -66,8 +73,9 @@ def write_json(text, path):
def write_model(model_path, input_base_path, model_size):
assert model_size in NUM_SHARDS
os.makedirs(model_path, exist_ok=True)
+tmp_model_path = os.path.join(model_path, "tmp")
+os.makedirs(tmp_model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
@@ -83,6 +91,7 @@ def write_model(model_path, input_base_path, model_size):
def permute(w):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if model_size == "7B":
# Not shared
@@ -97,10 +106,7 @@ def write_model(model_path, input_base_path, model_size):
param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
layer_i + 1,
n_layers + 1,
)
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
if model_size == "7B":
# Unsharded
state_dict = {
@@ -120,11 +126,15 @@ def write_model(model_path, input_base_path, model_size):
}
else:
# Sharded
+# Note that in the 13B checkpoint, not cloning the two following weights will result in the checkpoint
+# becoming 37GB instead of 26GB for some reason.
state_dict = {
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][f"layers.{layer_i}.attention_norm.weight"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
f"layers.{layer_i}.attention_norm.weight"
].clone(),
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
f"layers.{layer_i}.ffn_norm.weight"
],
].clone(),
}
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
@@ -169,12 +179,9 @@ def write_model(model_path, input_base_path, model_size):
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
-torch.save(state_dict, os.path.join(model_path, filename))
+torch.save(state_dict, os.path.join(tmp_model_path, filename))
filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
n_layers + 1,
n_layers + 1,
)
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
if model_size == "7B":
# Unsharded
state_dict = {
@@ -194,48 +201,38 @@ def write_model(model_path, input_base_path, model_size):
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
-torch.save(state_dict, os.path.join(model_path, filename))
+torch.save(state_dict, os.path.join(tmp_model_path, filename))
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
-write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json"))
-config_out = {
-    "architectures": ["LlamaForCausalLM"],
-    "bos_token_id": 1,
-    "eos_token_id": 2,
-    "hidden_act": "silu",
-    "hidden_size": dim,
-    "intermediate_size": compute_intermediate_size(dim),
-    "initializer_range": 0.02,
-    "max_sequence_length": 2048,
-    "model_type": "llama",
-    "num_attention_heads": params["n_heads"],
-    "num_hidden_layers": params["n_layers"],
-    "pad_token_id": 0,
-    "rms_norm_eps": params["norm_eps"],
-    "torch_dtype": "float16",
-    "transformers_version": "4.27.0.dev0",
-    "use_cache": True,
-    "vocab_size": 32000,
-}
-write_json(
-    config_out,
-    os.path.join(model_path, "config.json"),
-)
-generation_config = {
-    "_from_model_config": True,
-    "bos_token_id": 1,
-    "eos_token_id": 2,
-    "pad_token_id": 0,
-    "transformers_version": "4.27.0.dev0",
-}
-write_json(
-    generation_config,
-    os.path.join(model_path, "generation_config.json"),
+write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
+config = LlamaConfig(
+    hidden_size=dim,
+    intermediate_size=compute_intermediate_size(dim),
+    num_attention_heads=params["n_heads"],
+    num_hidden_layers=params["n_layers"],
+    rms_norm_eps=params["norm_eps"],
)
+config.save_pretrained(tmp_model_path)
+# Make space so we can load the model properly now.
+del state_dict
+del loaded
+gc.collect()
+print("Loading the checkpoint in a Llama model.")
+model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+# Avoid saving this as part of the config.
+del model.config._name_or_path
+print("Saving in the Transformers format.")
+model.save_pretrained(model_path)
+shutil.rmtree(tmp_model_path)
def write_tokenizer(tokenizer_path, input_tokenizer_path):
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
os.makedirs(tokenizer_path, exist_ok=True)
write_json({}, os.path.join(tokenizer_path, "special_tokens_map.json"))
write_json(
@@ -268,12 +265,12 @@ def main():
args = parser.parse_args()
if args.model_size != "tokenizer_only":
write_model(
-model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
+model_path=args.output_dir,
input_base_path=os.path.join(args.input_dir, args.model_size),
model_size=args.model_size,
)
write_tokenizer(
-tokenizer_path=os.path.join(args.output_dir, "tokenizer"),
+tokenizer_path=args.output_dir,
input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
)
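After the conversion finishes, a quick way to check the export is to load the output folder back and generate a few tokens. A minimal sketch, assuming the model and tokenizer were both written to `/output/path` as in the usage example above; the prompt and generation settings are arbitrary:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# With this change, the model and the tokenizer live in the same output directory.
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = LlamaForCausalLM.from_pretrained("/output/path", torch_dtype=torch.float16)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```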