Mamba2 conversion script for original models (#32580)

* first attempt at allowing both conversions from codestral and from the original mamba ssm

* allow fp16, which seems to be the default for mamba2

* dtype fix

* simplify codestral check, don't overwrite pad/eos/bos when codestral

* change file -> directory

* use path join to be safe

* style

* apply code review
- add util mamba2 tokenizer (gptneox with left padding)
- add models dict

* fix copies

* add tokenizer to docs

* empty commit to check for weird err

* make conversion dependent on the user-specified model type, with defaults for the original paper models

* small comment nit

* remove norm_before_gate in conversion

* simplify model dict by using shared keys directly + remove unnecessary attributes

* fix tokenization: remove separate mamba2 tokenizer, add padding option as kwarg to gptneox one and reuse it for the conversion script

* simplify even further as we pass padding side via **kwargs already
Anton Vlasjuk 2024-08-29 11:27:45 +02:00 committed by GitHub
parent 39bfb2f514
commit 92a75ff6b1
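
The tokenization fix described in the commit message boils down to reusing the GPT-NeoX tokenizer with a left-padding option instead of shipping a dedicated Mamba2 tokenizer. As a minimal sketch of what that means downstream, assuming only the hub checkpoint name that the script below already uses (everything else is illustrative):

from transformers import GPTNeoXTokenizerFast

# Same call the conversion script makes for original mamba2_ssm checkpoints:
# the GPT-NeoX tokenizer is loaded with left padding rather than a separate
# Mamba2 tokenizer class.
tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")

# The hub tokenizer may not define a pad token; fall back to EOS if needed.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Left padding keeps the most recent tokens at the end of every row, which is
# what a causal LM expects for batched generation.
batch = tokenizer(["Hey how are you doing?", "Hi"], padding=True, return_tensors="pt")
print(tokenizer.padding_side)    # "left"
print(batch["input_ids"].shape)  # torch.Size([2, <longest prompt length>])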


@@ -15,55 +15,179 @@
 """This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
 
 import argparse
+import json
+from functools import partial
+from os import path
+from typing import Dict, Optional
 
 import torch
 from safetensors import safe_open
+from safetensors.torch import save_model
 
-from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
+from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
 
 
-def convert_mamba2_checkpoint_file_to_huggingface_model_file(
-    mamba2_checkpoint_path: str, tokenizer_model_path: str, output_dir: str
-) -> None:
-    hf_config = Mamba2Config()
-    hf_model = Mamba2ForCausalLM(hf_config)
+def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
     # Load weights and config from paths
     original_state_dict = {}
-    with safe_open(mamba2_checkpoint_path, framework="pt") as f:
+    with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
         for k in f.keys():
             newk = k.removeprefix("model.")
             original_state_dict[newk] = f.get_tensor(k).clone()
 
+    return original_state_dict
+
+
+def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+    return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")
+
+
+def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
+    """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
+    hf_config = Mamba2Config()
+
+    # Switch to a different dict depending on model type
+    config_dict = mamba2_model_dict
+
+    # Set important values from config and recalculate other resulting entries
+    hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
+    hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim
+    hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]]
+    hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1)
+    hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
+    hf_config.bos_token_id = config_dict["bos_token_id"]
+    hf_config.pad_token_id = config_dict["pad_token_id"]
+    hf_config.eos_token_id = config_dict["eos_token_id"]
+
+    # Padded vocab size, mostly of 16 but 32 is also very common in different models
+    vocab_size = config_ssm["vocab_size"]
+    pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
+    if (vocab_size % pad_vocab_size_multiple) != 0:
+        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+    hf_config.vocab_size = vocab_size
+
+    return hf_config
+
+
+def load_and_save_tokenizer(
+    mamba2_model_type: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
+) -> None:
+    tokenizer = None
+
+    # Load tokenizer
+    if tokenizer_model_path is not None and mamba2_model_type == "codestral":
+        tokenizer_class = LlamaTokenizerFast
+        tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
+    elif mamba2_model_type == "mamba_ssm":
+        tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")
+
+    # Save tokenizer
+    if tokenizer is not None:
+        tokenizer.save_pretrained(output_dir)
+
+
+_MAMBA2_MODELS_DICT = {
+    "codestral": {
+        "hidden_size": "dim",
+        "num_hidden_layers": "n_layers",
+        "n_groups": "n_groups",
+        "bos_token_id": 0,
+        "pad_token_id": 1,
+        "eos_token_id": 2,
+        "config_name": "params.json",
+        "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
+    },
+    "mamba_ssm": {
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layer",
+        "n_groups": "ngroups",
+        "bos_token_id": 0,
+        "pad_token_id": 0,
+        "eos_token_id": 0,
+        "config_name": "config.json",
+        "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
+    },
+}
+
+
+def convert_mamba2_checkpoint_file_to_huggingface_model_file(
+    mamba2_checkpoint_path: str,
+    mamba2_model_type: str,
+    precision: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
+) -> None:
+    mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]
+
+    # Load and save config based on name
+    config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
+    with open(config_path, "r", encoding="utf-8") as json_file:
+        config = json.load(json_file)
+    hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
+    hf_config.save_pretrained(output_dir)
+
+    # Load state dict of the original model and transfer to hf model
+    original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
+    hf_model = Mamba2ForCausalLM(hf_config)
     hf_model.load_state_dict(original_state_dict)
 
     # Save new model to pytorch_dump_path
-    hf_model.to(torch.bfloat16).save_pretrained(output_dir)
-    tokenizer_class = LlamaTokenizerFast
-    tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
-    tokenizer.save_pretrained(output_dir)
+    dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
+    save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})
+
+    # Load and save tokenizer
+    mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-i",
-        "--mamba2_checkpoint_file",
+        "--mamba2_checkpoint_directory",
         type=str,
         required=True,
-        help="Path to a `pytorch_model.bin` mamba2_ssm checkpoint file to be converted.",
+        help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
     )
     parser.add_argument(
-        "-c",
-        "--tokenizer_model_path",
+        "-m",
+        "--mamba2_model_type",
         type=str,
+        default="mamba_ssm",
+        const="mamba_ssm",
         required=True,
-        help="Path to a `config.json` file corresponding to a Mamba2Config of the original mamba2_ssm model.",
+        choices=("codestral", "mamba_ssm"),
+        help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
+    )
+    parser.add_argument(
+        "-p",
+        "--precision",
+        type=str,
+        default="fp16",
+        const="fp16",
+        required=True,
+        choices=("fp32", "fp16", "bf16"),
+        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
     )
     parser.add_argument(
         "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
     )
+    parser.add_argument(
+        "-t",
+        "--tokenizer_model_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to a `codestral` tokenizer file.",
+    )
     args = parser.parse_args()
 
     convert_mamba2_checkpoint_file_to_huggingface_model_file(
-        args.mamba2_checkpoint_file, args.tokenizer_model_path, args.output_dir
+        args.mamba2_checkpoint_directory,
+        args.mamba2_model_type,
+        args.precision,
+        args.output_dir,
+        args.tokenizer_model_path,
     )
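
For reference, a sketch of how the converted checkpoint might be produced and then consumed end to end. The import path and all filesystem paths below are placeholders and not part of the commit; the function name, keyword arguments, and defaults come from the script above.

import torch

from transformers import AutoTokenizer, Mamba2ForCausalLM

# Hypothetical import path -- adjust to wherever the conversion script lives.
from convert_mamba2_ssm_checkpoint import convert_mamba2_checkpoint_file_to_huggingface_model_file

# Convert an original mamba2_ssm checkpoint directory (config.json + pytorch_model.bin).
# For a Codestral-style checkpoint, pass mamba2_model_type="codestral" and a tokenizer_model_path.
convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="/path/to/original_mamba2_ssm",
    mamba2_model_type="mamba_ssm",
    precision="fp16",  # the script's default; "fp32" and "bf16" are also accepted
    output_dir="/path/to/converted_model",
)

# The output directory is a regular transformers checkpoint: config, model.safetensors, tokenizer.
model = Mamba2ForCausalLM.from_pretrained("/path/to/converted_model", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("/path/to/converted_model")

inputs = tokenizer("Hey how are you doing?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))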