Fix convert_internvl_weights_to_hf.py to support local paths (#38264)

fix(internvl): add local path support to convert_internvl_weights_to_hf.py
Winston Castorp 2025-05-30 20:56:32 +08:00 committed by GitHub
parent 858ce6879a
commit d0fccbf7ef

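With this change, the converter no longer requires the source model to be one of the hub IDs hard-coded in LM_TYPE_CORRESPONDENCE: for an unrecognized (e.g. local) path, the new get_lm_type helper loads the checkpoint with trust_remote_code=True and maps its language-model architecture name to "qwen2" or "llama". A minimal sketch of the two code paths; the local directory name is hypothetical, and the import assumes a source checkout of transformers, since conversion scripts are not part of the public API:

    from transformers.models.internvl.convert_internvl_weights_to_hf import get_lm_type

    # Known hub ID: answered directly from the hard-coded LM_TYPE_CORRESPONDENCE table.
    print(get_lm_type("OpenGVLab/InternVL2_5-1B"))  # e.g. "qwen2"

    # Unknown/local path: the checkpoint is loaded with trust_remote_code=True and
    # config.llm_config.architectures[0] decides between "qwen2" and "llama".
    print(get_lm_type("./checkpoints/InternVL2_5-1B"))  # hypothetical local directory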

@@ -15,7 +15,7 @@
 import argparse
 import gc
 import os
 import re
-from typing import Optional
+from typing import Literal, Optional
 import torch
 from einops import rearrange
@@ -124,6 +124,29 @@ chat_template = (
 CONTEXT_LENGTH = 8192


+def get_lm_type(path: str) -> Literal["qwen2", "llama"]:
+    """
+    Determine the type of language model (either 'qwen2' or 'llama') based on a given model path.
+    """
+    if path not in LM_TYPE_CORRESPONDENCE.keys():
+        base_config = AutoModel.from_pretrained(path, trust_remote_code=True).config
+        lm_arch = base_config.llm_config.architectures[0]
+        if lm_arch == "InternLM2ForCausalLM":
+            lm_type = "llama"
+        elif lm_arch == "Qwen2ForCausalLM":
+            lm_type = "qwen2"
+        else:
+            raise ValueError(
+                f"Architecture '{lm_arch}' is not supported. Only 'Qwen2ForCausalLM' and 'InternLM2ForCausalLM' are recognized."
+            )
+    else:
+        lm_type: Literal["qwen2", "llama"] = LM_TYPE_CORRESPONDENCE[path]
+    return lm_type
+
+
 def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: Optional[str] = None):
     """
     This function should be applied only once, on the concatenated keys to efficiently rename using
@@ -138,7 +161,7 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: Optional[str] = None):
     output_dict = dict(zip(old_text_vision.split("\n"), new_text.split("\n")))
     old_text_language = "\n".join([key for key in state_dict_keys if key.startswith("language_model")])
     new_text = old_text_language
-    if LM_TYPE_CORRESPONDENCE[path] == "llama":
+    if get_lm_type(path) == "llama":
         for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA.items():
             new_text = re.sub(pattern, replacement, new_text)
     elif LM_TYPE_CORRESPONDENCE[path] == "qwen2":
@@ -177,7 +200,7 @@ def get_internvl_config(input_base_path):
     llm_config = base_config.llm_config.to_dict()
     vision_config = base_config.vision_config.to_dict()
     vision_config["use_absolute_position_embeddings"] = True
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if get_lm_type(input_base_path) == "qwen2":
         image_token_id = 151667
         language_config_class = Qwen2Config
     else:
@@ -188,7 +211,7 @@ def get_internvl_config(input_base_path):
     # Force use_cache to True
     llm_config["use_cache"] = True
     # Force correct eos_token_id for InternVL3
-    if "InternVL3" in input_base_path and LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if "InternVL3" in input_base_path and get_lm_type(input_base_path) == "qwen2":
         llm_config["eos_token_id"] = 151645

     vision_config = {k: v for k, v in vision_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
@@ -299,7 +322,7 @@ def write_model(
         processor.push_to_hub(hub_dir, use_temp_dir=True)

     # generation config
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "llama":
+    if get_lm_type(input_base_path) == "llama":
         print("Saving generation config...")
         # in the original model, eos_token is not the same in the text_config and the generation_config
         # ("</s>" - 2 in the text_config and "<|im_end|>" - 92542 in the generation_config)
@@ -323,7 +346,7 @@ def write_model(
 def write_tokenizer(
     save_dir: str, push_to_hub: bool = False, path: Optional[str] = None, hub_dir: Optional[str] = None
 ):
-    if LM_TYPE_CORRESPONDENCE[path] == "qwen2":
+    if get_lm_type(path) == "qwen2":
         tokenizer = AutoTokenizer.from_pretrained(
             "Qwen/Qwen2.5-VL-7B-Instruct",
             return_token_type_ids=False,
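With get_lm_type in place, write_tokenizer (and the other helpers) can be pointed at a local checkout directly. A hypothetical call against the signature shown above; all directory names are illustrative:

    write_tokenizer(
        save_dir="./converted/InternVL2_5-1B-hf",  # illustrative output directory
        push_to_hub=False,
        path="./checkpoints/InternVL2_5-1B",  # a local path now works here too
    )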