mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Convert SlimSAM checkpoints (#28379)
* First commit * Improve conversion script * Convert more checkpoints * Update src/transformers/models/sam/convert_sam_original_to_hf_format.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Rename file * More updates * Update docstring * Update script --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
c38a12270a
commit
5e4b69dc12
@ -14,6 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert SAM checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/facebookresearch/segment-anything.
|
||||
|
||||
Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.
|
||||
"""
|
||||
import argparse
|
||||
import re
|
||||
@ -33,6 +37,47 @@ from transformers import (
|
||||
)
|
||||
|
||||
|
||||
def get_config(model_name):
|
||||
if "slimsam-50" in model_name:
|
||||
vision_config = SamVisionConfig(
|
||||
hidden_size=384,
|
||||
mlp_dim=1536,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
global_attn_indexes=[2, 5, 8, 11],
|
||||
)
|
||||
elif "slimsam-77" in model_name:
|
||||
vision_config = SamVisionConfig(
|
||||
hidden_size=168,
|
||||
mlp_dim=696,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
global_attn_indexes=[2, 5, 8, 11],
|
||||
)
|
||||
elif "sam_vit_b" in model_name:
|
||||
vision_config = SamVisionConfig()
|
||||
elif "sam_vit_l" in model_name:
|
||||
vision_config = SamVisionConfig(
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
global_attn_indexes=[5, 11, 17, 23],
|
||||
)
|
||||
elif "sam_vit_h" in model_name:
|
||||
vision_config = SamVisionConfig(
|
||||
hidden_size=1280,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=16,
|
||||
global_attn_indexes=[7, 15, 23, 31],
|
||||
)
|
||||
|
||||
config = SamConfig(
|
||||
vision_config=vision_config,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
|
||||
"iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
|
||||
@ -88,63 +133,47 @@ def replace_keys(state_dict):
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_hub_id="ybelkada/segment-anything"):
|
||||
checkpoint_path = hf_hub_download(model_hub_id, f"checkpoints/{model_name}.pth")
|
||||
|
||||
if "sam_vit_b" in model_name:
|
||||
config = SamConfig()
|
||||
elif "sam_vit_l" in model_name:
|
||||
vision_config = SamVisionConfig(
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
global_attn_indexes=[5, 11, 17, 23],
|
||||
)
|
||||
|
||||
config = SamConfig(
|
||||
vision_config=vision_config,
|
||||
)
|
||||
elif "sam_vit_h" in model_name:
|
||||
vision_config = SamVisionConfig(
|
||||
hidden_size=1280,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=16,
|
||||
global_attn_indexes=[7, 15, 23, 31],
|
||||
)
|
||||
|
||||
config = SamConfig(
|
||||
vision_config=vision_config,
|
||||
)
|
||||
def convert_sam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub):
|
||||
config = get_config(model_name)
|
||||
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
state_dict = replace_keys(state_dict)
|
||||
|
||||
image_processor = SamImageProcessor()
|
||||
|
||||
processor = SamProcessor(image_processor=image_processor)
|
||||
hf_model = SamModel(config)
|
||||
hf_model.eval()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
hf_model.load_state_dict(state_dict)
|
||||
hf_model = hf_model.to("cuda")
|
||||
hf_model = hf_model.to(device)
|
||||
|
||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
||||
input_points = [[[400, 650]]]
|
||||
input_points = [[[500, 375]]]
|
||||
input_labels = [[1]]
|
||||
|
||||
inputs = processor(images=np.array(raw_image), return_tensors="pt").to("cuda")
|
||||
inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = hf_model(**inputs)
|
||||
scores = output.iou_scores.squeeze()
|
||||
|
||||
if model_name == "sam_vit_h_4b8939":
|
||||
assert scores[-1].item() == 0.579890251159668
|
||||
|
||||
if model_name == "sam_vit_b_01ec64":
|
||||
inputs = processor(
|
||||
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
|
||||
).to("cuda")
|
||||
).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = hf_model(**inputs)
|
||||
scores = output.iou_scores.squeeze()
|
||||
|
||||
elif model_name == "sam_vit_h_4b8939":
|
||||
inputs = processor(
|
||||
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = hf_model(**inputs)
|
||||
@ -154,7 +183,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
|
||||
|
||||
input_boxes = ((75, 275, 1725, 850),)
|
||||
|
||||
inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to("cuda")
|
||||
inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = hf_model(**inputs)
|
||||
@ -168,7 +197,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
|
||||
|
||||
inputs = processor(
|
||||
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
|
||||
).to("cuda")
|
||||
).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = hf_model(**inputs)
|
||||
@ -176,16 +205,31 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
|
||||
|
||||
assert scores[-1].item() == 0.9936047792434692
|
||||
|
||||
if pytorch_dump_folder is not None:
|
||||
processor.save_pretrained(pytorch_dump_folder)
|
||||
hf_model.save_pretrained(pytorch_dump_folder)
|
||||
|
||||
if push_to_hub:
|
||||
repo_id = f"nielsr/{model_name}" if "slimsam" in model_name else f"meta/{model_name}"
|
||||
processor.push_to_hub(repo_id)
|
||||
hf_model.push_to_hub(repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195"]
|
||||
choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195", "slimsam-50-uniform", "slimsam-77-uniform"]
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="sam_vit_h_4b8939",
|
||||
choices=choices,
|
||||
type=str,
|
||||
help="Path to hf config.json of model to convert",
|
||||
help="Name of the original model to convert",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the original checkpoint",
|
||||
)
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
@ -193,14 +237,14 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Whether to push the model and processor to the hub after converting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_hub_id",
|
||||
default="ybelkada/segment-anything",
|
||||
choices=choices,
|
||||
type=str,
|
||||
help="Path to hf config.json of model to convert",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_sam_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.model_hub_id)
|
||||
if "slimsam" in args.model_name:
|
||||
checkpoint_path = args.checkpoint_path
|
||||
if checkpoint_path is None:
|
||||
raise ValueError("You need to provide a checkpoint path for SlimSAM models.")
|
||||
else:
|
||||
checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{args.model_name}.pth")
|
||||
|
||||
convert_sam_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -784,7 +784,7 @@ src/transformers/models/rwkv/configuration_rwkv.py
|
||||
src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
|
||||
src/transformers/models/rwkv/modeling_rwkv.py
|
||||
src/transformers/models/sam/configuration_sam.py
|
||||
src/transformers/models/sam/convert_sam_original_to_hf_format.py
|
||||
src/transformers/models/sam/convert_sam_to_hf.py
|
||||
src/transformers/models/sam/image_processing_sam.py
|
||||
src/transformers/models/sam/modeling_sam.py
|
||||
src/transformers/models/sam/modeling_tf_sam.py
|
||||
|
Loading…
Reference in New Issue
Block a user