
DINOv2
DINOv2 is a vision foundation model that uses a ViT as a general-purpose feature extractor for downstream tasks such as image classification and depth estimation. It focuses on stabilizing and accelerating training with techniques such as faster, memory-efficient attention, sequence packing, improved stochastic depth, Fully Sharded Data Parallel (FSDP), and model distillation.
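Stochastic depth randomly drops residual branches during training. As we read the DINOv2 paper, the improved variant skips computing the dropped branches entirely instead of computing them and masking the result. The plain-Python sketch below illustrates that idea on a toy batch of scalars. The function name `drop_add_residual`, the random subset sampling, and the `batch_size / subset_size` rescaling are our own assumptions for illustration, not the DINOv2 or Transformers implementation.

```python
import random
from typing import Callable, List


def drop_add_residual(
    batch: List[float],
    residual_fn: Callable[[float], float],
    drop_ratio: float,
    rng: random.Random,
) -> List[float]:
    """Hypothetical sketch: add residual_fn(x) to a random subset of samples only.

    Dropped samples are never passed through residual_fn, so they cost no
    compute. Kept residuals are scaled by batch_size / subset_size so the
    expected update matches the no-drop case.
    """
    batch_size = len(batch)
    subset_size = max(int(batch_size * (1 - drop_ratio)), 1)
    kept = rng.sample(range(batch_size), subset_size)
    scale = batch_size / subset_size
    out = list(batch)  # dropped samples pass through unchanged
    for i in kept:
        out[i] = batch[i] + scale * residual_fn(batch[i])
    return out


calls = []


def residual(x: float) -> float:
    calls.append(x)  # record which samples were actually computed
    return 1.0


rng = random.Random(0)
out = drop_add_residual([0.0, 0.0, 0.0, 0.0], residual, drop_ratio=0.5, rng=rng)
print(len(calls), out)
```

With `drop_ratio=0.5` on a batch of four, only two samples run through the residual branch and their residuals are doubled, so the batch sum (4.0) equals what the undropped batch would produce.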
You can find all the original DINOv2 checkpoints under the Dinov2 collection.
Tip
Click on the DINOv2 models in the right sidebar for more examples of how to apply DINOv2 to different vision tasks.
The example below demonstrates how to classify an image with [`Pipeline`] or the [`AutoModelForImageClassification`] class.
```python
import torch
from transformers import pipeline

pipe = pipeline(
    task="image-classification",
    model="facebook/dinov2-small-imagenet1k-1-layer",
    torch_dtype=torch.float16,
    device=0
)

pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
```
```python
import requests
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small-imagenet1k-1-layer")
model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2-small-imagenet1k-1-layer",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)

# Move the pixel values to the model's device and cast them to float16 to match the weights
inputs = processor(images=image, return_tensors="pt").to(model.device, torch.float16)
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.
The example below uses torchao to quantize only the weights to int4.
```python
# pip install torchao
import requests
import torch
from transformers import TorchAoConfig, AutoImageProcessor, AutoModelForImageClassification
from torchao.quantization import Int4WeightOnlyConfig
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-giant-imagenet1k-1-layer")

# Quantize only the weights to int4; each group of 128 weights shares its quantization parameters
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2-giant-imagenet1k-1-layer",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)

# Move the pixel values to the model's device and cast them to bfloat16
inputs = processor(images=image, return_tensors="pt").to(model.device, torch.bfloat16)
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```
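The `group_size=128` argument in the example above controls how many consecutive weights share one set of quantization parameters. The plain-Python sketch below shows a simple asymmetric min/max scheme that maps each group to unsigned 4-bit codes (0 to 15) plus a per-group scale and offset. It only illustrates the general idea under our own assumptions; the helper names, the rounding, and the offset handling are hypothetical and do not reproduce torchao's kernels or packing format.

```python
from typing import List, Tuple


def quantize_group_int4(group: List[float]) -> Tuple[List[int], float, float]:
    """Hypothetical min/max quantization of one group to 4-bit codes in [0, 15]."""
    lo, hi = min(group), max(group)
    scale = (hi - lo) / 15 or 1.0  # avoid dividing by zero for constant groups
    codes = [min(15, max(0, round((w - lo) / scale))) for w in group]
    return codes, scale, lo


def dequantize_group_int4(codes: List[int], scale: float, lo: float) -> List[float]:
    return [lo + c * scale for c in codes]


def quantize_weights(weights: List[float], group_size: int):
    """Split a flat weight row into groups and quantize each group independently."""
    return [quantize_group_int4(weights[i:i + group_size])
            for i in range(0, len(weights), group_size)]


weights = [(-1) ** i * (i % 7) / 7 for i in range(256)]
groups = quantize_weights(weights, group_size=128)
restored = [w for codes, s, lo in groups for w in dequantize_group_int4(codes, s, lo)]
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(len(groups), max_err)
```

Because every group gets its own scale and offset, an outlier weight only reduces precision within its own group, and the reconstruction error stays within half a quantization step of that group. Smaller groups trade extra storage for the per-group parameters against better accuracy.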
Notes
- The example below shows how to split the output tensor into:
  - one embedding for the whole image, commonly referred to as the `CLS` token, useful for classification and retrieval
  - a set of local embeddings, one for each 14x14 patch of the input image, useful for dense tasks such as semantic segmentation

```python
import requests
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
print(image.height, image.width)  # [480, 640]

processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model = AutoModel.from_pretrained('facebook/dinov2-base')
patch_size = model.config.patch_size

inputs = processor(images=image, return_tensors="pt")
print(inputs.pixel_values.shape)  # [1, 3, 224, 224]
batch_size, rgb, img_height, img_width = inputs.pixel_values.shape
num_patches_height, num_patches_width = img_height // patch_size, img_width // patch_size
num_patches_flat = num_patches_height * num_patches_width

outputs = model(**inputs)
last_hidden_states = outputs[0]
print(last_hidden_states.shape)  # [1, 1 + 256, 768]
assert last_hidden_states.shape == (batch_size, 1 + num_patches_flat, model.config.hidden_size)

# The first token is the CLS token; the rest are patch tokens, reshaped into a 2D grid
cls_token = last_hidden_states[:, 0, :]
patch_features = last_hidden_states[:, 1:, :].unflatten(1, (num_patches_height, num_patches_width))
```
- Use `torch.jit.trace` to speed up inference. Tracing can introduce small numerical mismatches: the maximum absolute difference between the outputs of the original and traced model is on the order of 1e-4.

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model = AutoModel.from_pretrained('facebook/dinov2-base')

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs[0]

# We have to force return_dict=False for tracing
model.config.return_dict = False

with torch.no_grad():
    traced_model = torch.jit.trace(model, [inputs.pixel_values])
    traced_outputs = traced_model(inputs.pixel_values)

# Maximum absolute difference between the eager and traced outputs
print((last_hidden_states - traced_outputs[0]).abs().max())
```
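The shape arithmetic from the first note (one `CLS` token followed by one token per 14x14 patch) can be checked without downloading a model. The sketch below assumes the values used in the `dinov2-base` example (a 224x224 input, `patch_size=14`, `hidden_size=768`) and represents a single image's token sequence as a plain Python list. The helper names `expected_hidden_shape` and `split_tokens` are hypothetical; `split_tokens` mirrors the `[:, 0, :]` and `[:, 1:, :].unflatten(...)` indexing for one image.

```python
from typing import List, Tuple

Token = List[float]


def expected_hidden_shape(img_h: int, img_w: int, patch_size: int,
                          hidden_size: int, batch_size: int = 1) -> Tuple[int, int, int]:
    """Return (batch, 1 CLS token + one token per patch, hidden size)."""
    num_patches = (img_h // patch_size) * (img_w // patch_size)
    return batch_size, 1 + num_patches, hidden_size


def split_tokens(tokens: List[Token], grid_h: int, grid_w: int) -> Tuple[Token, List[List[Token]]]:
    """Separate the CLS token and arrange the patch tokens in a grid_h x grid_w grid."""
    if len(tokens) != 1 + grid_h * grid_w:
        raise ValueError("expected one CLS token plus grid_h * grid_w patch tokens")
    cls_token, patches = tokens[0], tokens[1:]
    # Row-major reshape, like unflatten(1, (grid_h, grid_w)) on the sequence axis
    grid = [patches[r * grid_w:(r + 1) * grid_w] for r in range(grid_h)]
    return cls_token, grid


print(expected_hidden_shape(224, 224, patch_size=14, hidden_size=768))  # (1, 257, 768)

# Toy sequence: 1 CLS token + a 2x3 patch grid of 4-dim tokens; token k is filled with k
tokens = [[float(k)] * 4 for k in range(1 + 2 * 3)]
cls_token, grid = split_tokens(tokens, grid_h=2, grid_w=3)
print(cls_token[0], [[t[0] for t in row] for row in grid])
```

A 224x224 input yields a 16x16 grid of patches (224 // 14 = 16), so the sequence holds 1 + 256 = 257 tokens, matching the `[1, 1 + 256, 768]` shape printed in the example.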
Dinov2Config
[[autodoc]] Dinov2Config
Dinov2Model
[[autodoc]] Dinov2Model
    - forward
Dinov2ForImageClassification
[[autodoc]] Dinov2ForImageClassification
    - forward
FlaxDinov2Model
[[autodoc]] FlaxDinov2Model
    - __call__
FlaxDinov2ForImageClassification
[[autodoc]] FlaxDinov2ForImageClassification
    - __call__