Fix Llava conversion for models that use safetensors to store weights (#35406)

* fix llava-med-v1.5-mistral-7b conversion

Signed-off-by: Isotr0py <2037008807@qq.com>

* add weights_only=True

Signed-off-by: Isotr0py <2037008807@qq.com>

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
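
Note on the second bullet: passing weights_only=True makes torch.load use a restricted unpickler that only rebuilds tensors and plain containers instead of executing arbitrary pickled Python objects from the checkpoint file. A minimal sketch of the pattern, assuming a locally downloaded model_state_dict.bin (the filename used by the conversion script in the diff below):

import torch

# weights_only=True restricts torch.load to reconstructing tensors and simple
# containers rather than unpickling arbitrary objects, which is safer for
# checkpoints fetched from the Hub.
state_dict = torch.load("model_state_dict.bin", map_location="cpu", weights_only=True)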
Authored by Isotr0py on 2025-01-06 16:59:38 +08:00, committed by GitHub
commit 32aa2db04a (parent b2f2977533)

src/transformers/models/llava/convert_llava_weights_to_hf.py

@@ -15,7 +15,7 @@ import argparse
 import glob
 
 import torch
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import file_exists, hf_hub_download, snapshot_download
 from safetensors import safe_open
 
 from transformers import (
@@ -140,11 +140,12 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
     with torch.device("meta"):
         model = LlavaForConditionalGeneration(config)
 
-    if "Qwen" in text_model_id:
-        state_dict = load_original_state_dict(old_state_dict_id)
-    else:
+    # Some llava variants like microsoft/llava-med-v1.5-mistral-7b use safetensors to store weights
+    if file_exists(old_state_dict_id, "model_state_dict.bin"):
         state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
-        state_dict = torch.load(state_dict_path, map_location="cpu")
+        state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
+    else:
+        state_dict = load_original_state_dict(old_state_dict_id)
 
     state_dict = convert_state_dict_to_hf(state_dict)
     model.load_state_dict(state_dict, strict=True, assign=True)
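
The else branch falls back to load_original_state_dict, whose body is not part of this hunk. Judging from the imports touched above (glob, safe_open, snapshot_download), a plausible sketch of such a safetensors loader looks like this; the helper actually defined in the script may differ in details:

import glob

from huggingface_hub import snapshot_download
from safetensors import safe_open


def load_original_state_dict(model_id):
    # Download only the safetensors shards of the original checkpoint.
    directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])

    original_state_dict = {}
    for path in glob.glob(f"{directory_path}/*.safetensors"):
        # safe_open reads tensors without unpickling arbitrary Python objects.
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                original_state_dict[key] = f.get_tensor(key)

    return original_state_dict

With the new file_exists check, converting a repository that only ships safetensors weights, such as microsoft/llava-med-v1.5-mistral-7b, simply falls through to this path, e.g. convert_llava_llama_to_hf(text_model_id=..., vision_model_id=..., output_hub_path=..., old_state_dict_id="microsoft/llava-med-v1.5-mistral-7b").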