Fix convert_internvl_weights_to_hf.py to support local paths (#38264)

fix(internvl): add local path support to convert_internvl_weights_to_hf.py
Author: Winston Castorp
Date:   2025-05-30 20:56:32 +08:00 (committed by GitHub)
parent 858ce6879a
commit d0fccbf7ef

@@ -15,7 +15,7 @@ import argparse
 import gc
 import os
 import re
-from typing import Optional
+from typing import Literal, Optional
 
 import torch
 from einops import rearrange
@@ -124,6 +124,29 @@ chat_template = (
 CONTEXT_LENGTH = 8192
 
 
+def get_lm_type(path: str) -> Literal["qwen2", "llama"]:
+    """
+    Determine the type of language model (either 'qwen2' or 'llama') based on a given model path.
+    """
+    if path not in LM_TYPE_CORRESPONDENCE.keys():
+        base_config = AutoModel.from_pretrained(path, trust_remote_code=True).config
+        lm_arch = base_config.llm_config.architectures[0]
+        if lm_arch == "InternLM2ForCausalLM":
+            lm_type = "llama"
+        elif lm_arch == "Qwen2ForCausalLM":
+            lm_type = "qwen2"
+        else:
+            raise ValueError(
+                f"Architecture '{lm_arch}' is not supported. Only 'Qwen2ForCausalLM' and 'InternLM2ForCausalLM' are recognized."
+            )
+    else:
+        lm_type: Literal["qwen2", "llama"] = LM_TYPE_CORRESPONDENCE[path]
+    return lm_type
+
+
 def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: Optional[str] = None):
     """
     This function should be applied only once, on the concatenated keys to efficiently rename using
@@ -138,7 +161,7 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: O
     output_dict = dict(zip(old_text_vision.split("\n"), new_text.split("\n")))
     old_text_language = "\n".join([key for key in state_dict_keys if key.startswith("language_model")])
     new_text = old_text_language
-    if LM_TYPE_CORRESPONDENCE[path] == "llama":
+    if get_lm_type(path) == "llama":
         for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA.items():
             new_text = re.sub(pattern, replacement, new_text)
     elif LM_TYPE_CORRESPONDENCE[path] == "qwen2":
@@ -177,7 +200,7 @@ def get_internvl_config(input_base_path):
     llm_config = base_config.llm_config.to_dict()
     vision_config = base_config.vision_config.to_dict()
     vision_config["use_absolute_position_embeddings"] = True
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if get_lm_type(input_base_path) == "qwen2":
         image_token_id = 151667
         language_config_class = Qwen2Config
     else:
@@ -188,7 +211,7 @@ def get_internvl_config(input_base_path):
     # Force use_cache to True
     llm_config["use_cache"] = True
 
     # Force correct eos_token_id for InternVL3
-    if "InternVL3" in input_base_path and LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if "InternVL3" in input_base_path and get_lm_type(input_base_path) == "qwen2":
         llm_config["eos_token_id"] = 151645
     vision_config = {k: v for k, v in vision_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
@@ -299,7 +322,7 @@ def write_model(
         processor.push_to_hub(hub_dir, use_temp_dir=True)
 
     # generation config
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "llama":
+    if get_lm_type(input_base_path) == "llama":
         print("Saving generation config...")
         # in the original model, eos_token is not the same in the text_config and the generation_config
         # ("</s>" - 2 in the text_config and "<|im_end|>" - 92542 in the generation_config)
@@ -323,7 +346,7 @@ def write_model(
 def write_tokenizer(
     save_dir: str, push_to_hub: bool = False, path: Optional[str] = None, hub_dir: Optional[str] = None
 ):
-    if LM_TYPE_CORRESPONDENCE[path] == "qwen2":
+    if get_lm_type(path) == "qwen2":
         tokenizer = AutoTokenizer.from_pretrained(
             "Qwen/Qwen2.5-VL-7B-Instruct",
             return_token_type_ids=False,
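
For reference, a minimal usage sketch of what this change enables (not part of the commit). The directory name below is hypothetical, and the snippet assumes get_lm_type and LM_TYPE_CORRESPONDENCE from the script above are in scope.

    # Hypothetical local checkpoint directory; any path that is not a key in
    # LM_TYPE_CORRESPONDENCE now takes the fallback branch instead of raising.
    local_path = "./checkpoints/InternVL2_5-8B"

    # Before this commit, call sites indexed LM_TYPE_CORRESPONDENCE[path] directly,
    # so converting from a local path raised KeyError. get_lm_type keeps the dict
    # lookup for known hub IDs and otherwise loads the checkpoint with
    # trust_remote_code=True and inspects config.llm_config.architectures[0].
    lm_type = get_lm_type(local_path)  # -> "qwen2" or "llama"
    print(lm_type)

One trade-off worth noting: the fallback goes through AutoModel.from_pretrained, which instantiates the full model just to read its config. AutoConfig.from_pretrained(path, trust_remote_code=True) would likely be a lighter way to reach the same llm_config, but the sketch follows the committed code.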