Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Fix convert_internvl_weights_to_hf.py to support local paths (#38264)
fix(internvl): add local path support to convert_internvl_weights_to_hf.py
parent 858ce6879a
commit d0fccbf7ef
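Why this is needed: the converter previously resolved the language-model flavor with a direct LM_TYPE_CORRESPONDENCE[path] lookup, and that hard-coded map presumably only covers known Hub repo ids, so pointing the script at a checkpoint stored on disk failed with a KeyError before conversion even started. The new get_lm_type() helper keeps the fast map lookup but falls back to inspecting the checkpoint's own config. A minimal before/after sketch (the local path is hypothetical):

    path = "./checkpoints/InternVL3-8B"     # hypothetical local snapshot
    # before this commit: KeyError for any path missing from the map
    lm_type = LM_TYPE_CORRESPONDENCE[path]
    # after this commit: unknown paths fall back to reading the config
    lm_type = get_lm_type(path)             # -> "qwen2" or "llama"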
convert_internvl_weights_to_hf.py
@@ -15,7 +15,7 @@ import argparse
 import gc
 import os
 import re
-from typing import Optional
+from typing import Literal, Optional
 
 import torch
 from einops import rearrange
@@ -124,6 +124,29 @@ chat_template = (
 CONTEXT_LENGTH = 8192
 
 
+def get_lm_type(path: str) -> Literal["qwen2", "llama"]:
+    """
+    Determine the type of language model (either 'qwen2' or 'llama') based on a given model path.
+    """
+    if path not in LM_TYPE_CORRESPONDENCE.keys():
+        base_config = AutoModel.from_pretrained(path, trust_remote_code=True).config
+
+        lm_arch = base_config.llm_config.architectures[0]
+
+        if lm_arch == "InternLM2ForCausalLM":
+            lm_type = "llama"
+        elif lm_arch == "Qwen2ForCausalLM":
+            lm_type = "qwen2"
+        else:
+            raise ValueError(
+                f"Architecture '{lm_arch}' is not supported. Only 'Qwen2ForCausalLM' and 'InternLM2ForCausalLM' are recognized."
+            )
+    else:
+        lm_type: Literal["qwen2", "llama"] = LM_TYPE_CORRESPONDENCE[path]
+
+    return lm_type
+
+
 def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: Optional[str] = None):
     """
     This function should be applied only once, on the concatenated keys to efficiently rename using
@@ -138,7 +161,7 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: O
     output_dict = dict(zip(old_text_vision.split("\n"), new_text.split("\n")))
     old_text_language = "\n".join([key for key in state_dict_keys if key.startswith("language_model")])
     new_text = old_text_language
-    if LM_TYPE_CORRESPONDENCE[path] == "llama":
+    if get_lm_type(path) == "llama":
         for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA.items():
             new_text = re.sub(pattern, replacement, new_text)
     elif LM_TYPE_CORRESPONDENCE[path] == "qwen2":
@@ -177,7 +200,7 @@ def get_internvl_config(input_base_path):
     llm_config = base_config.llm_config.to_dict()
     vision_config = base_config.vision_config.to_dict()
     vision_config["use_absolute_position_embeddings"] = True
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if get_lm_type(input_base_path) == "qwen2":
         image_token_id = 151667
         language_config_class = Qwen2Config
     else:
@@ -188,7 +211,7 @@ def get_internvl_config(input_base_path):
     # Force use_cache to True
     llm_config["use_cache"] = True
     # Force correct eos_token_id for InternVL3
-    if "InternVL3" in input_base_path and LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if "InternVL3" in input_base_path and get_lm_type(input_base_path) == "qwen2":
         llm_config["eos_token_id"] = 151645
 
     vision_config = {k: v for k, v in vision_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
@@ -299,7 +322,7 @@ def write_model(
         processor.push_to_hub(hub_dir, use_temp_dir=True)
 
     # generation config
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "llama":
+    if get_lm_type(input_base_path) == "llama":
         print("Saving generation config...")
         # in the original model, eos_token is not the same in the text_config and the generation_config
         # ("</s>" - 2 in the text_config and "<|im_end|>" - 92542 in the generation_config)
@@ -323,7 +346,7 @@ def write_model(
 def write_tokenizer(
     save_dir: str, push_to_hub: bool = False, path: Optional[str] = None, hub_dir: Optional[str] = None
 ):
-    if LM_TYPE_CORRESPONDENCE[path] == "qwen2":
+    if get_lm_type(path) == "qwen2":
         tokenizer = AutoTokenizer.from_pretrained(
             "Qwen/Qwen2.5-VL-7B-Instruct",
             return_token_type_ids=False,
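For reference, a sketch of what the fallback branch of get_lm_type() does for a path missing from LM_TYPE_CORRESPONDENCE, mirroring the hunk above (the local path is hypothetical):

    from transformers import AutoModel

    # Load the original (remote-code) InternVL checkpoint and read which
    # causal-LM architecture backs it; the helper maps that name onto a
    # conversion flavor.
    cfg = AutoModel.from_pretrained("./checkpoints/InternVL3-8B", trust_remote_code=True).config
    arch = cfg.llm_config.architectures[0]
    # Unknown architectures raise here (the real helper raises ValueError).
    lm_type = {"Qwen2ForCausalLM": "qwen2", "InternLM2ForCausalLM": "llama"}[arch]

One design note: the helper instantiates the full model just to read its config; if only the config is needed, AutoConfig.from_pretrained(path, trust_remote_code=True) would be a lighter alternative.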