mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add expected output to the sample code for ViTMSNForImageClassification
(#19183)
* chore: add expected output to the sample code. * add: imagenet-1k labels to the model config. * chore: apply code formatting. * chore: change the expected output.
This commit is contained in:
parent
368b649af6
commit
582d085bb2
@ -15,11 +15,13 @@
|
||||
"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import requests
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import ViTFeatureExtractor, ViTMSNConfig, ViTMSNModel
|
||||
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
@ -147,6 +149,13 @@ def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
config = ViTMSNConfig()
|
||||
config.num_labels = 1000
|
||||
|
||||
repo_id = "datasets/huggingface/label-files"
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
if "s16" in checkpoint_url:
|
||||
config.hidden_size = 384
|
||||
config.intermediate_size = 1536
|
||||
|
@ -632,6 +632,8 @@ class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> torch.manual_seed(2)
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
@ -644,6 +646,7 @@ class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_label = logits.argmax(-1).item()
|
||||
>>> print(model.config.id2label[predicted_label])
|
||||
Kerry blue terrier
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user