Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
FEAT: Add mistral v3 conversion script (#30981)
* add mistral v3 conversion script

* Update src/transformers/models/mistral/convert_mistral_weights_to_hf.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent d521ba5797, commit bfe6f513b9
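
For reference, a hypothetical invocation of the updated script with the new flag (all paths and the model size below are placeholders; the flag names follow the diff):

    python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
        --input_dir /path/to/mistral-v3-checkpoint \
        --model_size 7B \
        --output_dir /path/to/hf-output \
        --is_v3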
@@ -19,6 +19,7 @@ import shutil
 import warnings
 
 import torch
+from safetensors.torch import load_file as safe_load_file
 
 from transformers import (
     LlamaTokenizer,
@@ -76,7 +77,7 @@ def write_json(text, path):
         json.dump(text, f)
 
 
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, is_v3=False):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
@@ -88,8 +89,12 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     params = read_json(os.path.join(input_base_path, "params.json"))
     num_shards = NUM_SHARDS[model_size]
 
+    sliding_window = params.get("sliding_window", None)
+
     # For some reason this is a string in the params.json
-    sliding_window = int(params["sliding_window"])
+    if sliding_window is not None:
+        sliding_window = int(sliding_window)
+
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
     n_heads_per_shard = n_heads // num_shards
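
The `.get(...)` guard matters because the old code indexed `params["sliding_window"]` unconditionally. A minimal sketch of the difference (the two params dicts below are hypothetical, assuming v3 checkpoints may omit the key entirely):

    params_v2 = {"sliding_window": "4096"}  # hypothetical: the value is stored as a string in params.json
    params_v3 = {}                          # hypothetical: no "sliding_window" key at all

    for params in (params_v2, params_v3):
        sliding_window = params.get("sliding_window", None)
        if sliding_window is not None:
            sliding_window = int(sliding_window)
        # prints 4096, then None; the old int(params["sliding_window"]) would raise KeyError
        print(sliding_window)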
@@ -100,7 +105,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
         max_position_embeddings = 4096 * 8
 
     if tokenizer_path is not None:
-        tokenizer = tokenizer_class(tokenizer_path)
+        tokenizer = tokenizer_class(tokenizer_path + ".v3" if is_v3 else tokenizer_path)
         tokenizer.save_pretrained(model_path)
     vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
 
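
The conditional resolves which sentencepiece file to load. A short, self-contained sketch of the same logic (input_dir is a placeholder; the `.v3` suffix follows the diff, which implies v3 checkpoints ship the tokenizer as tokenizer.model.v3):

    import os

    input_dir = "/path/to/checkpoint"   # placeholder
    is_v3 = True                        # as set by the new --is_v3 flag
    spm_path = os.path.join(input_dir, "tokenizer.model")
    tokenizer_file = spm_path + ".v3" if is_v3 else spm_path
    print(tokenizer_file)               # .../tokenizer.model.v3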
@@ -118,11 +123,15 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
         return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
 
     print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
-    # Load weights
-    loaded = [
-        torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
-        for i in range(num_shards)
-    ]
+
+    # Load weights - for v3 models the consolidated weights are in a single file format in safetensors
+    if is_v3:
+        loaded = [safe_load_file(os.path.join(input_base_path, "consolidated.safetensors"))]
+    else:
+        loaded = [
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            for i in range(num_shards)
+        ]
     param_count = 0
     index_dict = {"weight_map": {}}
    for layer_i in range(n_layers):
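
Both branches produce the same shape of object, so the sharding logic below them is untouched. A standalone sketch of the loading step under that assumption (load_shards is a hypothetical helper name; file names follow the diff above):

    import os

    import torch
    from safetensors.torch import load_file as safe_load_file


    def load_shards(input_base_path, num_shards, is_v3):
        # safetensors' load_file returns a flat dict of tensors, so both branches
        # hand back list[dict[str, torch.Tensor]] and downstream code is unchanged
        if is_v3:
            return [safe_load_file(os.path.join(input_base_path, "consolidated.safetensors"))]
        return [
            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
            for i in range(num_shards)
        ]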
@@ -231,6 +240,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     del model.config._name_or_path
     model.config.torch_dtype = torch.float16
     print("Saving in the Transformers format.")
+
     model.save_pretrained(model_path, safe_serialization=safe_serialization)
     shutil.rmtree(tmp_model_path)
 
@@ -258,6 +268,9 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    parser.add_argument(
+        "--is_v3", action="store_true", help="Whether the checkpoints correspond to the 3rd version or not."
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
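
Unlike the pre-existing --safe_serialization argument (whose type=bool treats any non-empty string, including "False", as truthy), the new flag uses action="store_true", which only flips on presence. A quick standalone check:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--is_v3", action="store_true")
    print(p.parse_args([]).is_v3)            # False when the flag is absent
    print(p.parse_args(["--is_v3"]).is_v3)   # True when it is passed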
@@ -267,6 +280,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            is_v3=args.is_v3,
         )
     else:
         write_tokenizer(args.output_dir, spm_path)