mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix Llava conversion for models that use safetensors to store weights (#35406)
* fix llava-med-v1.5-mistral-7b conversion Signed-off-by: Isotr0py <2037008807@qq.com> * add weights_only=True Signed-off-by: Isotr0py <2037008807@qq.com> --------- Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
b2f2977533
commit
32aa2db04a
@ -15,7 +15,7 @@ import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from huggingface_hub import file_exists, hf_hub_download, snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
@ -140,11 +140,12 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
|
||||
with torch.device("meta"):
|
||||
model = LlavaForConditionalGeneration(config)
|
||||
|
||||
if "Qwen" in text_model_id:
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
else:
|
||||
# Some llava variants like microsoft/llava-med-v1.5-mistral-7b use safetensors to store weights
|
||||
if file_exists(old_state_dict_id, "model_state_dict.bin"):
|
||||
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
|
||||
state_dict = torch.load(state_dict_path, map_location="cpu")
|
||||
state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
Loading…
Reference in New Issue
Block a user