add fast support and option (#22724)

* add fast support and option

* update based on review

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/llama/convert_llama_weights_to_hf.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* nit

* add print

* fixup

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Arthur 2023-04-12 18:10:04 +02:00 committed by GitHub
parent 10fab90fe2
commit 9858195481
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,12 +17,22 @@ import json
import math
import os
import shutil
import warnings
import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
try:
from transformers import LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
LlamaTokenizerFast = None
"""
Sample usage:
@ -232,9 +242,10 @@ def write_model(model_path, input_base_path, model_size):
def write_tokenizer(tokenizer_path, input_tokenizer_path):
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
# Initialize the tokenizer based on the `spm` model
tokenizer = LlamaTokenizer(input_tokenizer_path)
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
print("Saving a {tokenizer_class} to {tokenizer_path}")
tokenizer = tokenizer_class(input_tokenizer_path)
tokenizer.save_pretrained(tokenizer_path)
@ -259,10 +270,8 @@ def main():
input_base_path=os.path.join(args.input_dir, args.model_size),
model_size=args.model_size,
)
write_tokenizer(
tokenizer_path=args.output_dir,
input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
)
spm_path = os.path.join(args.input_dir, "tokenizer.model")
write_tokenizer(args.output_dir, spm_path)
if __name__ == "__main__":