Add usage example for DINOv2 (#37398)

* Add usage example for DINOv2

* More explicit shape names

* More verbose text

* Moved example to Notes section

* Indentation
Federico Baldassarre 2025-05-01 17:54:22 +02:00 committed by GitHub
parent d20aa68193
commit e94a4807df


@@ -111,33 +111,68 @@ print("Predicted class:", model.config.id2label[predicted_class_idx])
## Notes
- The example below, followed by a short usage sketch, shows how to split the output tensor into:
  - one embedding for the whole image, commonly referred to as a `CLS` token, useful for classification and retrieval
  - a set of local embeddings, one for each `14x14` patch of the input image, useful for dense tasks such as semantic segmentation

  ```py
  from transformers import AutoImageProcessor, AutoModel
  from PIL import Image
  import requests

  url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
  image = Image.open(requests.get(url, stream=True).raw)
  print(image.height, image.width)  # [480, 640]

  processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
  model = AutoModel.from_pretrained('facebook/dinov2-base')
  patch_size = model.config.patch_size

  inputs = processor(images=image, return_tensors="pt")
  print(inputs.pixel_values.shape)  # [1, 3, 224, 224]
  batch_size, rgb, img_height, img_width = inputs.pixel_values.shape
  num_patches_height, num_patches_width = img_height // patch_size, img_width // patch_size
  num_patches_flat = num_patches_height * num_patches_width

  outputs = model(**inputs)
  last_hidden_states = outputs[0]
  print(last_hidden_states.shape)  # [1, 1 + 256, 768]
  assert last_hidden_states.shape == (batch_size, 1 + num_patches_flat, model.config.hidden_size)

  # the first token is the global CLS embedding, the rest are per-patch features
  cls_token = last_hidden_states[:, 0, :]
  patch_features = last_hidden_states[:, 1:, :].unflatten(1, (num_patches_height, num_patches_width))
  ```
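
  For a sense of how the two outputs are consumed downstream, here is a minimal sketch that continues from the example above. It compares the `CLS` embeddings of two images with cosine similarity, a simple retrieval signal, and bilinearly upsamples the patch features back to the input resolution, as a dense prediction head would. The second image URL, the similarity measure, and the upsampling mode are illustrative choices, not part of the model's API.

  ```py
  import torch.nn.functional as F

  # image-level: compare two images through their CLS embeddings (retrieval-style)
  url2 = 'http://images.cocodataset.org/val2017/000000000285.jpg'  # arbitrary second image
  image2 = Image.open(requests.get(url2, stream=True).raw)
  inputs2 = processor(images=image2, return_tensors="pt")
  cls_token2 = model(**inputs2)[0][:, 0, :]
  print(F.cosine_similarity(cls_token, cls_token2))  # one score per batch element

  # patch-level: move channels first and upsample the 16x16 patch grid
  # back to the 224x224 input resolution, as a segmentation head would
  feature_map = patch_features.permute(0, 3, 1, 2)  # [1, 768, 16, 16]
  upsampled = F.interpolate(feature_map, size=(img_height, img_width), mode="bilinear", align_corners=False)
  print(upsampled.shape)  # [1, 768, 224, 224]
  ```
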
- Use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) to speed up inference. However, tracing produces some mismatched elements; the maximum difference between the outputs of the original and the traced model is on the order of 1e-4.

  ```py
  import torch
  from transformers import AutoImageProcessor, AutoModel
  from PIL import Image
  import requests

  url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
  image = Image.open(requests.get(url, stream=True).raw)

  processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
  model = AutoModel.from_pretrained('facebook/dinov2-base')

  inputs = processor(images=image, return_tensors="pt")
  outputs = model(**inputs)
  last_hidden_states = outputs[0]

  # We have to force return_dict=False for tracing,
  # so the model returns a tuple instead of a ModelOutput
  model.config.return_dict = False

  with torch.no_grad():
      traced_model = torch.jit.trace(model, [inputs.pixel_values])
      traced_outputs = traced_model(inputs.pixel_values)

  # maximum absolute difference between eager and traced outputs (~1e-4)
  print((last_hidden_states - traced_outputs[0]).abs().max())
  ```
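
  Once traced, the model can be serialized with [torch.jit.save](https://pytorch.org/docs/stable/generated/torch.jit.save.html) and reloaded without the original Python class, for example in a C++ runtime. A minimal sketch, with an arbitrary file name:

  ```py
  torch.jit.save(traced_model, "dinov2_traced.pt")  # file name is an arbitrary choice
  reloaded = torch.jit.load("dinov2_traced.pt")
  print((reloaded(inputs.pixel_values)[0] - traced_outputs[0]).abs().max())  # expect 0
  ```
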
## Dinov2Config