mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[GIT] Convert more checkpoints (#21245)
* Extend conversion script * Remove print statement Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
66459ce319
commit
6e4d3f0859
@ -246,6 +246,11 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
|
||||
"git-large-msrvtt-qa": (
|
||||
"https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_MSRVTT_QA/snapshot/model.pt"
|
||||
),
|
||||
"git-large-r": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R/snapshot/model.pt",
|
||||
"git-large-r-coco": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_COCO/snapshot/model.pt",
|
||||
"git-large-r-textcaps": (
|
||||
"https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_TEXTCAPS/snapshot/model.pt"
|
||||
),
|
||||
}
|
||||
|
||||
model_name_to_path = {
|
||||
@ -258,7 +263,7 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
|
||||
|
||||
# define GIT configuration based on model name
|
||||
config, image_size, is_video = get_git_config(model_name)
|
||||
if "large" in model_name and not is_video:
|
||||
if "large" in model_name and not is_video and "large-r" not in model_name:
|
||||
# large checkpoints take way too long to download
|
||||
checkpoint_path = model_name_to_path[model_name]
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||
@ -349,6 +354,12 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
|
||||
expected_slice_logits = torch.tensor([-1.0113, -1.0114, -1.0113])
|
||||
elif model_name == "git-large-msrvtt-qa":
|
||||
expected_slice_logits = torch.tensor([0.0130, 0.0134, 0.0131])
|
||||
elif model_name == "git-large-r":
|
||||
expected_slice_logits = torch.tensor([-1.1283, -1.1285, -1.1286])
|
||||
elif model_name == "git-large-r-coco":
|
||||
expected_slice_logits = torch.tensor([-0.9641, -0.9641, -0.9641])
|
||||
elif model_name == "git-large-r-textcaps":
|
||||
expected_slice_logits = torch.tensor([-1.1121, -1.1120, -1.1124])
|
||||
|
||||
assert torch.allclose(logits[0, -1, :3], expected_slice_logits, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
Loading…
Reference in New Issue
Block a user