Add expected output to the sample code for ViTMSNForImageClassification (#19183)

* chore: add expected output to the sample code.

* add: imagenet-1k labels to the model config.

* chore: apply code formatting.

* chore: change the expected output.
This commit is contained in:
Sayak Paul 2022-09-30 18:55:41 +05:30 committed by GitHub
parent 368b649af6
commit 582d085bb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 0 deletions

View File

@ -15,11 +15,13 @@
"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""
import argparse
import json
import torch
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
from transformers import ViTFeatureExtractor, ViTMSNConfig, ViTMSNModel
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -147,6 +149,13 @@ def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config = ViTMSNConfig()
config.num_labels = 1000
repo_id = "datasets/huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if "s16" in checkpoint_url:
config.hidden_size = 384
config.intermediate_size = 1536

View File

@ -632,6 +632,8 @@ class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
>>> from PIL import Image
>>> import requests
>>> torch.manual_seed(2)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
@ -644,6 +646,7 @@ class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
Kerry blue terrier
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict