Mirror of https://github.com/huggingface/transformers.git
add fast support and option (#22724)
* add fast support and option
* update based on review
* Apply suggestions from code review
* Update src/transformers/models/llama/convert_llama_weights_to_hf.py
* nit
* add print
* fixup

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
commit 9858195481
parent 10fab90fe2
src/transformers/models/llama/convert_llama_weights_to_hf.py:

@@ -17,12 +17,22 @@ import json
 import math
 import os
 import shutil
+import warnings

 import torch

 from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer


+try:
+    from transformers import LlamaTokenizerFast
+except ImportError as e:
+    warnings.warn(e)
+    warnings.warn(
+        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
+    )
+    LlamaTokenizerFast = None
+
 """
 Sample usage:

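Taken together with the `write_tokenizer` change below, the guarded import is the usual optional-dependency fallback: warn, leave a `None` sentinel, and let the call site pick the best class available. A minimal runnable sketch of the combined behavior (same names as the diff; the f-string in the print is an assumption, the hunk below prints the placeholders literally as committed):

```python
import warnings

# Guarded import: fall back to a None sentinel when the installed
# `tokenizers` library cannot provide the fast class.
try:
    from transformers import LlamaTokenizerFast
except ImportError as e:
    warnings.warn(e)
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, "
        "update your `tokenizers` library and re-run the tokenizer conversion"
    )
    LlamaTokenizerFast = None

from transformers import LlamaTokenizer

# Call sites select the best class available at import time.
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
print(f"Converting with {tokenizer_class.__name__}")
```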
@@ -232,9 +242,10 @@ def write_model(model_path, input_base_path, model_size):


 def write_tokenizer(tokenizer_path, input_tokenizer_path):
-    print(f"Fetching the tokenizer from {input_tokenizer_path}.")
     # Initialize the tokenizer based on the `spm` model
-    tokenizer = LlamaTokenizer(input_tokenizer_path)
+    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
+    print(f"Saving a {tokenizer_class} to {tokenizer_path}")
+    tokenizer = tokenizer_class(input_tokenizer_path)
     tokenizer.save_pretrained(tokenizer_path)

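In isolation, the reworked `write_tokenizer` can be exercised directly. A hypothetical call (both paths are illustrative, not from the PR); `save_pretrained` writes a `tokenizer.json` when the fast class is used, or the SentencePiece model plus tokenizer configs for the slow one:

```python
# Hypothetical direct call; both paths are illustrative.
write_tokenizer(
    tokenizer_path="/output/path",                          # directory save_pretrained() populates
    input_tokenizer_path="/path/to/llama/tokenizer.model",  # original SentencePiece model
)
```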
@@ -259,10 +270,8 @@ def main():
         input_base_path=os.path.join(args.input_dir, args.model_size),
         model_size=args.model_size,
     )
-    write_tokenizer(
-        tokenizer_path=args.output_dir,
-        input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
-    )
+    spm_path = os.path.join(args.input_dir, "tokenizer.model")
+    write_tokenizer(args.output_dir, spm_path)


 if __name__ == "__main__":
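Once `main()` has written both model and tokenizer, a hypothetical smoke test (the path is illustrative) can reload the converted artifacts and confirm which tokenizer class the fallback produced:

```python
# Hypothetical post-conversion check; the path is illustrative.
from transformers import AutoTokenizer, LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = AutoTokenizer.from_pretrained("/output/path")
# Prints LlamaTokenizerFast when the fast tokenizer was available at
# conversion/load time, LlamaTokenizer otherwise.
print(type(tokenizer).__name__)
```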