Mamba2 conversion script for original models (#32580)
* first attempt at allowing both conversions from codestral and from the original mamba ssm
* allow fp16, seems default for mamba2
* dtype fix
* simplify codestral check, dont overwrite pad/eos/bos when codestral
* change file -> directory
* use path join to be safe
* style
* apply code review
  - add util mamba2 tokenizer (gptneox with left padding)
  - add models dict
* fix copies
* add tokenizer to docs
* empty commit to check for weird err
* make conversion user dependent on model type, defaults for original paper models
* small comment nit
* remove norm_before_gate in conversion
* simplify model dict by using shared keys directly + remove unnecessary attributes
* fix tokenization: remove separate mamba2 tokenizer, add padding option as kwarg to gptneox one and reuse it for the conversion script
* simplify even further as we pass padding side via **kwargs already
parent 39bfb2f514
commit 92a75ff6b1
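For reference, a minimal sketch of how the new entry point introduced by this commit can be driven from Python. The function and parameter names come from the diff below; the import path and the checkpoint/output paths are placeholders, not part of the commit.

# Sketch only: the module name and the paths are hypothetical placeholders.
from convert_mamba2_ssm_checkpoint_to_pytorch import (
    convert_mamba2_checkpoint_file_to_huggingface_model_file,
)

convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="/path/to/checkpoint_dir",  # a directory, not a single file
    mamba2_model_type="mamba_ssm",  # or "codestral"
    precision="fp16",  # one of "fp32", "fp16", "bf16"
    output_dir="/path/to/output_dir",
    tokenizer_model_path=None,  # only used for the "codestral" tokenizer
)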
@@ -15,55 +15,179 @@
 """This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
 
 import argparse
+import json
+from functools import partial
+from os import path
+from typing import Dict, Optional
 
 import torch
 from safetensors import safe_open
+from safetensors.torch import save_model
 
-from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
+from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
 
 
-def convert_mamba2_checkpoint_file_to_huggingface_model_file(
-    mamba2_checkpoint_path: str, tokenizer_model_path: str, output_dir: str
-) -> None:
-    hf_config = Mamba2Config()
-    hf_model = Mamba2ForCausalLM(hf_config)
+def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
     # Load weights and config from paths
     original_state_dict = {}
-    with safe_open(mamba2_checkpoint_path, framework="pt") as f:
+    with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
         for k in f.keys():
             newk = k.removeprefix("model.")
             original_state_dict[newk] = f.get_tensor(k).clone()
+    return original_state_dict
+
+
+def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+    return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")
+
+
+def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
+    """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
+    hf_config = Mamba2Config()
+
+    # Switch to a different dict depending on model type
+    config_dict = mamba2_model_dict
+
+    # Set important values from config and recalculate other resulting entries
+    hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
+    hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim
+    hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]]
+    hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1)
+    hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
+    hf_config.bos_token_id = config_dict["bos_token_id"]
+    hf_config.pad_token_id = config_dict["pad_token_id"]
+    hf_config.eos_token_id = config_dict["eos_token_id"]
+
+    # Padded vocab size, mostly of 16 but 32 is also very common in different models
+    vocab_size = config_ssm["vocab_size"]
+    pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
+    if (vocab_size % pad_vocab_size_multiple) != 0:
+        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+    hf_config.vocab_size = vocab_size
+
+    return hf_config
+
+
+def load_and_save_tokenizer(
+    mamba2_model_type: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
+) -> None:
+    tokenizer = None
+
+    # Load tokenizer
+    if tokenizer_model_path is not None and mamba2_model_type == "codestral":
+        tokenizer_class = LlamaTokenizerFast
+        tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
+    elif mamba2_model_type == "mamba_ssm":
+        tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")
+
+    # Save tokenizer
+    if tokenizer is not None:
+        tokenizer.save_pretrained(output_dir)
+
+
+_MAMBA2_MODELS_DICT = {
+    "codestral": {
+        "hidden_size": "dim",
+        "num_hidden_layers": "n_layers",
+        "n_groups": "n_groups",
+        "bos_token_id": 0,
+        "pad_token_id": 1,
+        "eos_token_id": 2,
+        "config_name": "params.json",
+        "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
+    },
+    "mamba_ssm": {
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layer",
+        "n_groups": "ngroups",
+        "bos_token_id": 0,
+        "pad_token_id": 0,
+        "eos_token_id": 0,
+        "config_name": "config.json",
+        "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
+    },
+}
+
+
+def convert_mamba2_checkpoint_file_to_huggingface_model_file(
+    mamba2_checkpoint_path: str,
+    mamba2_model_type: str,
+    precision: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
+) -> None:
+    mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]
+
+    # Load and save config based on name
+    config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
+    with open(config_path, "r", encoding="utf-8") as json_file:
+        config = json.load(json_file)
+    hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
+    hf_config.save_pretrained(output_dir)
+
+    # Load state dict of the original model and transfer to hf model
+    original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
+    hf_model = Mamba2ForCausalLM(hf_config)
     hf_model.load_state_dict(original_state_dict)
 
     # Save new model to pytorch_dump_path
-    hf_model.to(torch.bfloat16).save_pretrained(output_dir)
-    tokenizer_class = LlamaTokenizerFast
-    tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
-    tokenizer.save_pretrained(output_dir)
+    dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
+    save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})
+
+    # Load and save tokenizer
+    mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-i",
-        "--mamba2_checkpoint_file",
+        "--mamba2_checkpoint_directory",
         type=str,
         required=True,
-        help="Path to a `pytorch_model.bin` mamba2_ssm checkpoint file to be converted.",
+        help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
     )
     parser.add_argument(
-        "-c",
-        "--tokenizer_model_path",
+        "-m",
+        "--mamba2_model_type",
         type=str,
+        default="mamba_ssm",
+        const="mamba_ssm",
         required=True,
-        help="Path to a `config.json` file corresponding to a Mamba2Config of the original mamba2_ssm model.",
+        choices=("codestral", "mamba_ssm"),
+        help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
     )
     parser.add_argument(
+        "-p",
+        "--precision",
+        type=str,
+        default="fp16",
+        const="fp16",
+        required=True,
+        choices=("fp32", "fp16", "bf16"),
+        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
+    )
+    parser.add_argument(
         "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
     )
+    parser.add_argument(
+        "-t",
+        "--tokenizer_model_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to a `codestral` tokenizer file.",
+    )
     args = parser.parse_args()
 
     convert_mamba2_checkpoint_file_to_huggingface_model_file(
-        args.mamba2_checkpoint_file, args.tokenizer_model_path, args.output_dir
+        args.mamba2_checkpoint_directory,
+        args.mamba2_model_type,
+        args.precision,
+        args.output_dir,
+        args.tokenizer_model_path,
     )
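Assuming the conversion runs through, the output directory holds the saved config, the `model.safetensors` weights, and the tokenizer, so it can be loaded back with the standard `transformers` API. A minimal sketch with a placeholder output path:

# Sketch only: "/path/to/output_dir" stands in for the -o/--output_dir value used above.
import torch
from transformers import AutoTokenizer, Mamba2ForCausalLM

tokenizer = AutoTokenizer.from_pretrained("/path/to/output_dir")
model = Mamba2ForCausalLM.from_pretrained("/path/to/output_dir", torch_dtype=torch.float16)

inputs = tokenizer("Hey how are you doing?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))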