[GIT] Convert more checkpoints (#21245)

* Extend conversion script

* Remove print statement

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2023-01-23 15:19:27 +01:00 committed by GitHub
parent 66459ce319
commit 6e4d3f0859
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -246,6 +246,11 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
"git-large-msrvtt-qa": (
"https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_MSRVTT_QA/snapshot/model.pt"
),
"git-large-r": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R/snapshot/model.pt",
"git-large-r-coco": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_COCO/snapshot/model.pt",
"git-large-r-textcaps": (
"https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_TEXTCAPS/snapshot/model.pt"
),
}
model_name_to_path = {
@@ -258,7 +263,7 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
# define GIT configuration based on model name
config, image_size, is_video = get_git_config(model_name)
if "large" in model_name and not is_video:
if "large" in model_name and not is_video and "large-r" not in model_name:
# large checkpoints take way too long to download
checkpoint_path = model_name_to_path[model_name]
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
@@ -349,6 +354,12 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
expected_slice_logits = torch.tensor([-1.0113, -1.0114, -1.0113])
elif model_name == "git-large-msrvtt-qa":
expected_slice_logits = torch.tensor([0.0130, 0.0134, 0.0131])
elif model_name == "git-large-r":
expected_slice_logits = torch.tensor([-1.1283, -1.1285, -1.1286])
elif model_name == "git-large-r-coco":
expected_slice_logits = torch.tensor([-0.9641, -0.9641, -0.9641])
elif model_name == "git-large-r-textcaps":
expected_slice_logits = torch.tensor([-1.1121, -1.1120, -1.1124])
assert torch.allclose(logits[0, -1, :3], expected_slice_logits, atol=1e-4)
print("Looks ok!")