mirror of https://github.com/huggingface/transformers.git
FEAT: Add mistral v3 conversion script (#30981)
* add mistral v3 conversion script

* Update src/transformers/models/mistral/convert_mistral_weights_to_hf.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
commit bfe6f513b9
parent d521ba5797
src/transformers/models/mistral/convert_mistral_weights_to_hf.py

@@ -19,6 +19,7 @@ import shutil
 import warnings
 
 import torch
+from safetensors.torch import load_file as safe_load_file
 
 from transformers import (
     LlamaTokenizer,
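For context on the new import: safetensors.torch.load_file reads a .safetensors checkpoint straight into a {name: tensor} dict on the host, much like torch.load with map_location="cpu". A minimal sketch (the checkpoint path is a placeholder):

    from safetensors.torch import load_file as safe_load_file

    # Returns a plain dict mapping parameter names to torch.Tensor,
    # loaded onto the CPU by default.
    state_dict = safe_load_file("/path/to/consolidated.safetensors")
    print(list(state_dict)[:5])  # peek at the first few parameter names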
@@ -76,7 +77,7 @@ def write_json(text, path):
         json.dump(text, f)
 
 
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, is_v3=False):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
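Since the new keyword defaults to False, existing callers of write_model are unaffected and only v3 conversions need to opt in. A hedged sketch (the paths and the "7B" size are placeholders):

    # Pre-v3 call sites keep their old behavior:
    write_model("out-dir", "ckpt-dir", model_size="7B")
    # v3 checkpoints opt in explicitly:
    write_model("out-dir", "ckpt-dir", model_size="7B", is_v3=True)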
@@ -88,8 +89,12 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     params = read_json(os.path.join(input_base_path, "params.json"))
     num_shards = NUM_SHARDS[model_size]
 
+    sliding_window = params.get("sliding_window", None)
+
     # For some reason this is a string in the params.json
-    sliding_window = int(params["sliding_window"])
+    if sliding_window is not None:
+        sliding_window = int(sliding_window)
+
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
     n_heads_per_shard = n_heads // num_shards
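The point of this hunk: params.json for newer checkpoints may not contain a sliding_window entry at all, and the old int(params["sliding_window"]) raised a KeyError there. A small sketch with a hypothetical params dict:

    params = {"n_layers": 32, "n_heads": 32}  # hypothetical v3-style params.json, no sliding_window

    sliding_window = params.get("sliding_window", None)
    if sliding_window is not None:
        sliding_window = int(sliding_window)  # the value is stored as a string when present

    assert sliding_window is None  # no KeyError; downstream config simply gets None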
@@ -100,7 +105,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     max_position_embeddings = 4096 * 8
 
     if tokenizer_path is not None:
-        tokenizer = tokenizer_class(tokenizer_path)
+        tokenizer = tokenizer_class(tokenizer_path + ".v3" if is_v3 else "")
         tokenizer.save_pretrained(model_path)
     vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
 
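One note on how the new expression parses: the conditional binds looser than +, so the tokenizer class receives either the ".v3"-suffixed path or an empty string:

    tokenizer_path = "/ckpt/tokenizer.model"  # placeholder
    is_v3 = True

    # Equivalent to: (tokenizer_path + ".v3") if is_v3 else ""
    arg = tokenizer_path + ".v3" if is_v3 else ""
    print(arg)  # "/ckpt/tokenizer.model.v3" here; "" when is_v3 is False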
@@ -118,11 +123,15 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
         return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
 
     print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
-    # Load weights
-    loaded = [
-        torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
-        for i in range(num_shards)
-    ]
+
+    # Load weights - for v3 models the consolidated weights are in a single file format in safetensors
+    if is_v3:
+        loaded = [safe_load_file(os.path.join(input_base_path, "consolidated.safetensors"))]
+    else:
+        loaded = [
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            for i in range(num_shards)
+        ]
     param_count = 0
     index_dict = {"weight_map": {}}
     for layer_i in range(n_layers):
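Both branches leave loaded as a list of state dicts, so the per-shard merging loop below runs unchanged, with the v3 path acting as a single shard. The two on-disk layouts the branch distinguishes, as far as the diff shows (tokenizer file names follow from the ".v3" suffix logic above; the rest is assumed):

    pre-v3 checkpoint dir:        v3 checkpoint dir:
        params.json                   params.json
        tokenizer.model               tokenizer.model.v3
        consolidated.00.pth, ...      consolidated.safetensors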
@@ -231,6 +240,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     del model.config._name_or_path
     model.config.torch_dtype = torch.float16
     print("Saving in the Transformers format.")
+
     model.save_pretrained(model_path, safe_serialization=safe_serialization)
     shutil.rmtree(tmp_model_path)
 
@@ -258,6 +268,9 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    parser.add_argument(
+        "--is_v3", action="store_true", help="Whether the checkpoints correspond to the 3rd version or not."
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
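Because the flag uses action="store_true", args.is_v3 defaults to False when the flag is omitted, so existing command lines behave exactly as before. A minimal standalone check:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--is_v3", action="store_true")

    print(parser.parse_args([]).is_v3)            # False
    print(parser.parse_args(["--is_v3"]).is_v3)   # True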
@@ -267,6 +280,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            is_v3=args.is_v3,
         )
     else:
         write_tokenizer(args.output_dir, spm_path)
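Putting the pieces together, a hypothetical invocation for a v3 checkpoint (paths are placeholders; the accepted --model_size values come from the script's NUM_SHARDS table):

    python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
        --input_dir /path/to/mistral-7B-v0.3 \
        --model_size 7B \
        --output_dir /path/to/hf-converted \
        --is_v3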