Mirror of https://github.com/huggingface/transformers.git
Convert SlimSAM checkpoints (#28379)
* First commit
* Improve conversion script
* Convert more checkpoints
* Update src/transformers/models/sam/convert_sam_original_to_hf_format.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Rename file
* More updates
* Update docstring
* Update script

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent c38a12270a
commit 5e4b69dc12
--- a/src/transformers/models/sam/convert_sam_original_to_hf_format.py
+++ b/src/transformers/models/sam/convert_sam_to_hf.py
@@ -14,6 +14,10 @@
 # limitations under the License.
 """
 Convert SAM checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/segment-anything.
+
+Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.
 """
 import argparse
 import re
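For orientation, an example invocation of the renamed script (the flags match the argparse options added further down in this diff; the local checkpoint path is a hypothetical placeholder):

    python src/transformers/models/sam/convert_sam_to_hf.py \
        --model_name slimsam-50-uniform \
        --checkpoint_path ./SlimSAM-50.pth \
        --pytorch_dump_folder_path ./slimsam-50-uniform

For the original SAM checkpoints (sam_vit_b_01ec64, sam_vit_l_0b3195, sam_vit_h_4b8939), --checkpoint_path can be omitted; the script downloads the .pth file from the ybelkada/segment-anything repository via hf_hub_download, as shown in the __main__ block below.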
@@ -33,6 +37,47 @@ from transformers import (
 )
 
 
+def get_config(model_name):
+    if "slimsam-50" in model_name:
+        vision_config = SamVisionConfig(
+            hidden_size=384,
+            mlp_dim=1536,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            global_attn_indexes=[2, 5, 8, 11],
+        )
+    elif "slimsam-77" in model_name:
+        vision_config = SamVisionConfig(
+            hidden_size=168,
+            mlp_dim=696,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            global_attn_indexes=[2, 5, 8, 11],
+        )
+    elif "sam_vit_b" in model_name:
+        vision_config = SamVisionConfig()
+    elif "sam_vit_l" in model_name:
+        vision_config = SamVisionConfig(
+            hidden_size=1024,
+            num_hidden_layers=24,
+            num_attention_heads=16,
+            global_attn_indexes=[5, 11, 17, 23],
+        )
+    elif "sam_vit_h" in model_name:
+        vision_config = SamVisionConfig(
+            hidden_size=1280,
+            num_hidden_layers=32,
+            num_attention_heads=16,
+            global_attn_indexes=[7, 15, 23, 31],
+        )
+
+    config = SamConfig(
+        vision_config=vision_config,
+    )
+
+    return config
+
+
 KEYS_TO_MODIFY_MAPPING = {
     "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
     "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
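For reference, the SlimSAM branches above describe pruned ViT image encoders. A minimal sketch (not part of the diff) of the configuration that get_config builds for "slimsam-77-uniform", using only classes the script already imports:

    from transformers import SamConfig, SamVisionConfig

    # SlimSAM-77 keeps the 12-layer ViT-B layout but prunes the widths
    # compared with the ViT-B defaults (hidden_size 768, mlp_dim 3072).
    vision_config = SamVisionConfig(
        hidden_size=168,
        mlp_dim=696,
        num_hidden_layers=12,
        num_attention_heads=12,
        global_attn_indexes=[2, 5, 8, 11],
    )
    config = SamConfig(vision_config=vision_config)
    print(config.vision_config.hidden_size)  # 168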
@@ -88,63 +133,47 @@ def replace_keys(state_dict):
     return model_state_dict
 
 
-def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_hub_id="ybelkada/segment-anything"):
-    checkpoint_path = hf_hub_download(model_hub_id, f"checkpoints/{model_name}.pth")
-
-    if "sam_vit_b" in model_name:
-        config = SamConfig()
-    elif "sam_vit_l" in model_name:
-        vision_config = SamVisionConfig(
-            hidden_size=1024,
-            num_hidden_layers=24,
-            num_attention_heads=16,
-            global_attn_indexes=[5, 11, 17, 23],
-        )
-
-        config = SamConfig(
-            vision_config=vision_config,
-        )
-    elif "sam_vit_h" in model_name:
-        vision_config = SamVisionConfig(
-            hidden_size=1280,
-            num_hidden_layers=32,
-            num_attention_heads=16,
-            global_attn_indexes=[7, 15, 23, 31],
-        )
-
-        config = SamConfig(
-            vision_config=vision_config,
-        )
+def convert_sam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub):
+    config = get_config(model_name)
 
     state_dict = torch.load(checkpoint_path, map_location="cpu")
     state_dict = replace_keys(state_dict)
 
     image_processor = SamImageProcessor()
 
     processor = SamProcessor(image_processor=image_processor)
     hf_model = SamModel(config)
+    hf_model.eval()
 
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
     hf_model.load_state_dict(state_dict)
-    hf_model = hf_model.to("cuda")
+    hf_model = hf_model.to(device)
 
     img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
     raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
 
-    input_points = [[[400, 650]]]
+    input_points = [[[500, 375]]]
     input_labels = [[1]]
 
-    inputs = processor(images=np.array(raw_image), return_tensors="pt").to("cuda")
+    inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device)
 
     with torch.no_grad():
         output = hf_model(**inputs)
     scores = output.iou_scores.squeeze()
 
-    if model_name == "sam_vit_h_4b8939":
-        assert scores[-1].item() == 0.579890251159668
-
-    inputs = processor(
-        images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
-    ).to("cuda")
+    if model_name == "sam_vit_b_01ec64":
+        inputs = processor(
+            images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
+        ).to(device)
+
+        with torch.no_grad():
+            output = hf_model(**inputs)
+            scores = output.iou_scores.squeeze()
+
+    elif model_name == "sam_vit_h_4b8939":
+        inputs = processor(
+            images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
+        ).to(device)
 
-    with torch.no_grad():
-        output = hf_model(**inputs)
+        with torch.no_grad():
+            output = hf_model(**inputs)
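Note that the load_state_dict call above only succeeds because replace_keys() (whose body is outside this hunk) has already renamed the original SAM parameter names to the Hugging Face layout using KEYS_TO_MODIFY_MAPPING and pattern-based rules. An illustrative sketch of the substring-renaming idea only, not the actual replace_keys implementation:

    # Illustrative only: dictionary-driven renaming in the spirit of KEYS_TO_MODIFY_MAPPING.
    def rename_keys(state_dict, mapping):
        renamed = {}
        for key, value in state_dict.items():
            for old, new in mapping.items():
                if old in key:
                    key = key.replace(old, new)
            renamed[key] = value
        return renamed

    mapping = {
        "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
        "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
    }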
@@ -154,7 +183,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
 
         input_boxes = ((75, 275, 1725, 850),)
 
-    inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to("cuda")
+        inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device)
 
         with torch.no_grad():
             output = hf_model(**inputs)
@@ -168,7 +197,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
 
         inputs = processor(
             images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
-    ).to("cuda")
+        ).to(device)
 
         with torch.no_grad():
             output = hf_model(**inputs)
@@ -176,16 +205,31 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
 
         assert scores[-1].item() == 0.9936047792434692
 
+    if pytorch_dump_folder is not None:
+        processor.save_pretrained(pytorch_dump_folder)
+        hf_model.save_pretrained(pytorch_dump_folder)
+
+    if push_to_hub:
+        repo_id = f"nielsr/{model_name}" if "slimsam" in model_name else f"meta/{model_name}"
+        processor.push_to_hub(repo_id)
+        hf_model.push_to_hub(repo_id)
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195"]
+    choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195", "slimsam-50-uniform", "slimsam-77-uniform"]
     parser.add_argument(
         "--model_name",
         default="sam_vit_h_4b8939",
         choices=choices,
         type=str,
-        help="Path to hf config.json of model to convert",
+        help="Name of the original model to convert",
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        type=str,
+        required=False,
+        help="Path to the original checkpoint",
     )
     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
     parser.add_argument(
@@ -193,14 +237,14 @@ if __name__ == "__main__":
         action="store_true",
         help="Whether to push the model and processor to the hub after converting",
     )
-    parser.add_argument(
-        "--model_hub_id",
-        default="ybelkada/segment-anything",
-        choices=choices,
-        type=str,
-        help="Path to hf config.json of model to convert",
-    )
 
     args = parser.parse_args()
 
-    convert_sam_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.model_hub_id)
+    if "slimsam" in args.model_name:
+        checkpoint_path = args.checkpoint_path
+        if checkpoint_path is None:
+            raise ValueError("You need to provide a checkpoint path for SlimSAM models.")
+    else:
+        checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{args.model_name}.pth")
+
+    convert_sam_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
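Once the conversion has run, the dumped (or pushed) checkpoint behaves like any other SAM model. A minimal usage sketch, assuming the weights were saved to ./slimsam-50-uniform (a hypothetical local folder; a pushed repo id such as nielsr/slimsam-50-uniform from the push_to_hub branch above would work the same way):

    import numpy as np
    import requests
    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    model = SamModel.from_pretrained("./slimsam-50-uniform")
    processor = SamProcessor.from_pretrained("./slimsam-50-uniform")

    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    # Same point prompt the conversion script uses for its sanity checks.
    inputs = processor(images=np.array(raw_image), input_points=[[[500, 375]]], input_labels=[[1]], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.iou_scores.squeeze())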
@@ -784,7 +784,7 @@ src/transformers/models/rwkv/configuration_rwkv.py
 src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
 src/transformers/models/rwkv/modeling_rwkv.py
 src/transformers/models/sam/configuration_sam.py
-src/transformers/models/sam/convert_sam_original_to_hf_format.py
+src/transformers/models/sam/convert_sam_to_hf.py
 src/transformers/models/sam/image_processing_sam.py
 src/transformers/models/sam/modeling_sam.py
 src/transformers/models/sam/modeling_tf_sam.py