mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add usage example for DINOv2 (#37398)
* Add usage example for DINOv2 * More explicit shape names * More verbose text * Moved example to Notes section * Indentation
This commit is contained in:
parent
d20aa68193
commit
e94a4807df
@ -111,33 +111,68 @@ print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
|
||||
## Notes
|
||||
|
||||
- Use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) to speedup inference. However, it will produce some mismatched elements. The difference between the original and traced model is 1e-4.
|
||||
- The example below shows how to split the output tensor into:
|
||||
- one embedding for the whole image, commonly referred to as a `CLS` token,
|
||||
useful for classification and retrieval
|
||||
- a set of local embeddings, one for each `14x14` patch of the input image,
|
||||
useful for dense tasks, such as semantic segmentation
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from PIL import Image
|
||||
import requests
|
||||
```py
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
print(image.height, image.width) # [480, 640]
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
||||
model = AutoModel.from_pretrained('facebook/dinov2-base')
|
||||
patch_size = model.config.patch_size
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt")
|
||||
print(inputs.pixel_values.shape) # [1, 3, 224, 224]
|
||||
batch_size, rgb, img_height, img_width = inputs.pixel_values.shape
|
||||
num_patches_height, num_patches_width = img_height // patch_size, img_width // patch_size
|
||||
num_patches_flat = num_patches_height * num_patches_width
|
||||
|
||||
outputs = model(**inputs)
|
||||
last_hidden_states = outputs[0]
|
||||
print(last_hidden_states.shape) # [1, 1 + 256, 768]
|
||||
assert last_hidden_states.shape == (batch_size, 1 + num_patches_flat, model.config.hidden_size)
|
||||
|
||||
cls_token = last_hidden_states[:, 0, :]
|
||||
patch_features = last_hidden_states[:, 1:, :].unflatten(1, (num_patches_height, num_patches_width))
|
||||
```
|
||||
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
- Use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) to speedup inference.
|
||||
However, it will produce some mismatched elements. The difference between the original and traced model is 1e-4.
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
||||
model = AutoModel.from_pretrained('facebook/dinov2-base')
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
last_hidden_states = outputs[0]
|
||||
|
||||
# We have to force return_dict=False for tracing
|
||||
model.config.return_dict = False
|
||||
|
||||
with torch.no_grad():
|
||||
traced_model = torch.jit.trace(model, [inputs.pixel_values])
|
||||
traced_outputs = traced_model(inputs.pixel_values)
|
||||
|
||||
print((last_hidden_states - traced_outputs[0]).abs().max())
|
||||
```
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
||||
model = AutoModel.from_pretrained('facebook/dinov2-base')
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
last_hidden_states = outputs[0]
|
||||
|
||||
# We have to force return_dict=False for tracing
|
||||
model.config.return_dict = False
|
||||
|
||||
with torch.no_grad():
|
||||
traced_model = torch.jit.trace(model, [inputs.pixel_values])
|
||||
traced_outputs = traced_model(inputs.pixel_values)
|
||||
|
||||
print((last_hidden_states - traced_outputs[0]).abs().max())
|
||||
```
|
||||
|
||||
## Dinov2Config
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user