diff --git a/src/transformers/models/internvl/convert_internvl_weights_to_hf.py b/src/transformers/models/internvl/convert_internvl_weights_to_hf.py
index fa6d4bc9e52..a1437266688 100644
--- a/src/transformers/models/internvl/convert_internvl_weights_to_hf.py
+++ b/src/transformers/models/internvl/convert_internvl_weights_to_hf.py
@@ -15,7 +15,7 @@ import argparse
 import gc
 import os
 import re
-from typing import Optional
+from typing import Literal, Optional
 
 import torch
 from einops import rearrange
@@ -124,6 +124,29 @@ chat_template = (
 CONTEXT_LENGTH = 8192
 
 
+def get_lm_type(path: str) -> Literal["qwen2", "llama"]:
+    """
+    Determine the type of language model (either 'qwen2' or 'llama') based on a given model path.
+    """
+    if path not in LM_TYPE_CORRESPONDENCE.keys():
+        base_config = AutoModel.from_pretrained(path, trust_remote_code=True).config
+
+        lm_arch = base_config.llm_config.architectures[0]
+
+        if lm_arch == "InternLM2ForCausalLM":
+            lm_type = "llama"
+        elif lm_arch == "Qwen2ForCausalLM":
+            lm_type = "qwen2"
+        else:
+            raise ValueError(
+                f"Architecture '{lm_arch}' is not supported. Only 'Qwen2ForCausalLM' and 'InternLM2ForCausalLM' are recognized."
+            )
+    else:
+        lm_type: Literal["qwen2", "llama"] = LM_TYPE_CORRESPONDENCE[path]
+
+    return lm_type
+
+
 def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: Optional[str] = None):
     """
     This function should be applied only once, on the concatenated keys to efficiently rename using
@@ -138,7 +161,7 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None, path: O
     output_dict = dict(zip(old_text_vision.split("\n"), new_text.split("\n")))
     old_text_language = "\n".join([key for key in state_dict_keys if key.startswith("language_model")])
     new_text = old_text_language
-    if LM_TYPE_CORRESPONDENCE[path] == "llama":
+    if get_lm_type(path) == "llama":
         for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA.items():
             new_text = re.sub(pattern, replacement, new_text)
-    elif LM_TYPE_CORRESPONDENCE[path] == "qwen2":
+    elif get_lm_type(path) == "qwen2":
@@ -177,7 +200,7 @@ def get_internvl_config(input_base_path):
     llm_config = base_config.llm_config.to_dict()
     vision_config = base_config.vision_config.to_dict()
     vision_config["use_absolute_position_embeddings"] = True
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if get_lm_type(input_base_path) == "qwen2":
         image_token_id = 151667
         language_config_class = Qwen2Config
     else:
@@ -188,7 +211,7 @@ def get_internvl_config(input_base_path):
     # Force use_cache to True
     llm_config["use_cache"] = True
     # Force correct eos_token_id for InternVL3
-    if "InternVL3" in input_base_path and LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
+    if "InternVL3" in input_base_path and get_lm_type(input_base_path) == "qwen2":
         llm_config["eos_token_id"] = 151645
 
     vision_config = {k: v for k, v in vision_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
@@ -299,7 +322,7 @@ def write_model(
         processor.push_to_hub(hub_dir, use_temp_dir=True)
 
     # generation config
-    if LM_TYPE_CORRESPONDENCE[input_base_path] == "llama":
+    if get_lm_type(input_base_path) == "llama":
         print("Saving generation config...")
         # in the original model, eos_token is not the same in the text_config and the generation_config
         # ("</s>" - 2 in the text_config and "<|im_end|>" - 92542 in the generation_config)
@@ -323,7 +346,7 @@ def write_model(
 def write_tokenizer(
     save_dir: str, push_to_hub: bool = False, path: Optional[str] = None, hub_dir: Optional[str] = None
 ):
-    if LM_TYPE_CORRESPONDENCE[path] == "qwen2":
+    if get_lm_type(path) == "qwen2":
         tokenizer = AutoTokenizer.from_pretrained(
             "Qwen/Qwen2.5-VL-7B-Instruct",
             return_token_type_ids=False,