Convert SlimSAM checkpoints (#28379)

* First commit

* Improve conversion script

* Convert more checkpoints

* Update src/transformers/models/sam/convert_sam_original_to_hf_format.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Rename file

* More updates

* Update docstring

* Update script

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
NielsRogge 2024-03-04 11:51:16 +01:00 committed by GitHub
parent c38a12270a
commit 5e4b69dc12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 48 deletions

View File

@ -14,6 +14,10 @@
# limitations under the License.
"""
Convert SAM checkpoints from the original repository.
URL: https://github.com/facebookresearch/segment-anything.
Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.
"""
import argparse
import re
@ -33,6 +37,47 @@ from transformers import (
)
def get_config(model_name):
if "slimsam-50" in model_name:
vision_config = SamVisionConfig(
hidden_size=384,
mlp_dim=1536,
num_hidden_layers=12,
num_attention_heads=12,
global_attn_indexes=[2, 5, 8, 11],
)
elif "slimsam-77" in model_name:
vision_config = SamVisionConfig(
hidden_size=168,
mlp_dim=696,
num_hidden_layers=12,
num_attention_heads=12,
global_attn_indexes=[2, 5, 8, 11],
)
elif "sam_vit_b" in model_name:
vision_config = SamVisionConfig()
elif "sam_vit_l" in model_name:
vision_config = SamVisionConfig(
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
global_attn_indexes=[5, 11, 17, 23],
)
elif "sam_vit_h" in model_name:
vision_config = SamVisionConfig(
hidden_size=1280,
num_hidden_layers=32,
num_attention_heads=16,
global_attn_indexes=[7, 15, 23, 31],
)
config = SamConfig(
vision_config=vision_config,
)
return config
KEYS_TO_MODIFY_MAPPING = {
"iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
"iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
@ -88,63 +133,47 @@ def replace_keys(state_dict):
return model_state_dict
def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_hub_id="ybelkada/segment-anything"):
checkpoint_path = hf_hub_download(model_hub_id, f"checkpoints/{model_name}.pth")
if "sam_vit_b" in model_name:
config = SamConfig()
elif "sam_vit_l" in model_name:
vision_config = SamVisionConfig(
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
global_attn_indexes=[5, 11, 17, 23],
)
config = SamConfig(
vision_config=vision_config,
)
elif "sam_vit_h" in model_name:
vision_config = SamVisionConfig(
hidden_size=1280,
num_hidden_layers=32,
num_attention_heads=16,
global_attn_indexes=[7, 15, 23, 31],
)
config = SamConfig(
vision_config=vision_config,
)
def convert_sam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub):
config = get_config(model_name)
state_dict = torch.load(checkpoint_path, map_location="cpu")
state_dict = replace_keys(state_dict)
image_processor = SamImageProcessor()
processor = SamProcessor(image_processor=image_processor)
hf_model = SamModel(config)
hf_model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
hf_model.load_state_dict(state_dict)
hf_model = hf_model.to("cuda")
hf_model = hf_model.to(device)
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[400, 650]]]
input_points = [[[500, 375]]]
input_labels = [[1]]
inputs = processor(images=np.array(raw_image), return_tensors="pt").to("cuda")
inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device)
with torch.no_grad():
output = hf_model(**inputs)
scores = output.iou_scores.squeeze()
if model_name == "sam_vit_h_4b8939":
assert scores[-1].item() == 0.579890251159668
if model_name == "sam_vit_b_01ec64":
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to("cuda")
).to(device)
with torch.no_grad():
output = hf_model(**inputs)
scores = output.iou_scores.squeeze()
elif model_name == "sam_vit_h_4b8939":
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(device)
with torch.no_grad():
output = hf_model(**inputs)
@ -154,7 +183,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
input_boxes = ((75, 275, 1725, 850),)
inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to("cuda")
inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device)
with torch.no_grad():
output = hf_model(**inputs)
@ -168,7 +197,7 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to("cuda")
).to(device)
with torch.no_grad():
output = hf_model(**inputs)
@ -176,16 +205,31 @@ def convert_sam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_h
assert scores[-1].item() == 0.9936047792434692
if pytorch_dump_folder is not None:
processor.save_pretrained(pytorch_dump_folder)
hf_model.save_pretrained(pytorch_dump_folder)
if push_to_hub:
repo_id = f"nielsr/{model_name}" if "slimsam" in model_name else f"meta/{model_name}"
processor.push_to_hub(repo_id)
hf_model.push_to_hub(repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195"]
choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195", "slimsam-50-uniform", "slimsam-77-uniform"]
parser.add_argument(
"--model_name",
default="sam_vit_h_4b8939",
choices=choices,
type=str,
help="Path to hf config.json of model to convert",
help="Name of the original model to convert",
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=False,
help="Path to the original checkpoint",
)
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
@ -193,14 +237,14 @@ if __name__ == "__main__":
action="store_true",
help="Whether to push the model and processor to the hub after converting",
)
parser.add_argument(
"--model_hub_id",
default="ybelkada/segment-anything",
choices=choices,
type=str,
help="Path to hf config.json of model to convert",
)
args = parser.parse_args()
convert_sam_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.model_hub_id)
if "slimsam" in args.model_name:
checkpoint_path = args.checkpoint_path
if checkpoint_path is None:
raise ValueError("You need to provide a checkpoint path for SlimSAM models.")
else:
checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{args.model_name}.pth")
convert_sam_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -784,7 +784,7 @@ src/transformers/models/rwkv/configuration_rwkv.py
src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
src/transformers/models/rwkv/modeling_rwkv.py
src/transformers/models/sam/configuration_sam.py
src/transformers/models/sam/convert_sam_original_to_hf_format.py
src/transformers/models/sam/convert_sam_to_hf.py
src/transformers/models/sam/image_processing_sam.py
src/transformers/models/sam/modeling_sam.py
src/transformers/models/sam/modeling_tf_sam.py