transformers/docs/source/en/model_doc/vjepa2.md
Pavel Iakubovskii 9bec2654ed
Add V-JEPA for video classification model (#38788)
* adding model and conversion scripts

* add imports to test vjepa conversion

* fix imports and make conversion work

* fix computation for short side

* replace attention with library attention function

* cleanup more attention classes

* remove config overrides

* add test cases, fix some of the failing ones

* fix the model outputs

* fix outputs of the model per review

* fix too big model test case

* fix styling __init__.py

* fix initialization test

* remove all asserts per review

* update sorting unsorting logic as per feedback

* remove is_video per review

* remove another is_video segment

* remove unwanted stuff

* small fixes

* add docstrings for the model

* revert adding vjepa2 config here

* update styling

* add config docstrings (wip)

* fix dpr issue

* removed test failing issues

* update styles

* merge predictor configs into main config

* remove processing code, add video processor

* remove permute which is not necessary now

* fix styles

* updated vjepa2 to be in video_processing_auto

* update comment for preprocessing

* test integration test and fix the outputs

* update test values, change test to look at repeated frames for a given image

* add a simple video processing test

* refactoring pixel_values_videos and upload ckpts to original

* fix torch_fx test cases

* remove unused config

* add all config docstrings

* add more integration tests

* add basic doc

* revert unwanted styling changes

* working make fixup

* Fix model_type in config

* Add ForVideoClassification model

* update attention implementation to fit new hf standards

* fix the preprocessing logic, ensure it matches the original model

* remove use_rope logic, cleanup

* fix docstrings

* Further cleanup, update doc

* Fix model prefix

* fix get_vision_features

* VJEPA2Embeddings style refactor

* nit, style comment

* change modules default values

* Only `str` activation in config

* GradientCheckpointingLayer

* fixup

* fix conversion script

* Remove return_dict

* remove None return typehint

* Refactor VJEPA2Layer, remove use_SiLU

* Fix fx tests

* dpr -> drop_path_rates

* move *ModelOutput on top

* format docs bit

* update docs

* update docs

* update doc example

* remove prune_heads from model

* remove unused config params

* refactor embed signature

* Add vjepa to docs

* Fix config docstring

* attention head

* update defaults

* Update docs/source/en/model_doc/vjepa2.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/vjepa2.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix import

* Min refactoring

* Update HUB_SOURCE and HUB_REPO in conversion script

* Add missing headers

* VJEPA -> V-JEPA in docs

* Add image to doc

* fix style

* fix init weights

* change checkpoint name in modeling tests

* Initial cls head setup

* remove rop attention from head (not needed)

* remove swigluffn - not needed

* Add siglip layer

* Replace with siglip layer

* Rename Siglip - VJEPA2

* remove unused modules

* remove siglip mlp

* nit

* remove MLP

* Refactor head cross attention

* refactor VJEPA2HeadCrossAttentionLayer

* nit renaming

* fixup

* remove commented code

* Add cls head params to config

* depth from config

* move pooler + classifier  to the model

* Update for cls model signature

* move layers, rename a bit

* fix docs

* update weights init

* remove typehint for init

* add to auto-mapping

* enable tests

* Add conversion script

* fixup

* add to docs

* fix docs

* nit

* refactor for mapping

* clean

* Add integration test

* Fixing multi gpu test

* update not-split-modules

* update video cls test tolerance

* Increase test_inference_image tolerance

* Update no-split modules for multi gpu

* Apply suggestions from code review

* fixing multi-gpu

* fix docstring

* Add cls snippet to docs

* Update checkpoint
2025-06-13 17:56:15 +01:00


PyTorch SDPA FlashAttention

V-JEPA 2

V-JEPA 2 is a self-supervised approach to training video encoders developed by FAIR, Meta. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.


You can find all original V-JEPA 2 checkpoints under the V-JEPA 2 collection.

This model was contributed by koustuvs, yonigozlan and qubvel. The original code can be found here.

Usage example

The snippet below shows how to load the V-JEPA 2 model for feature extraction using the AutoModel class.

```python
import numpy as np
import torch
from torchcodec.decoders import VideoDecoder
from transformers import AutoModel, AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
model = AutoModel.from_pretrained(
    "facebook/vjepa2-vitl-fpc64-256",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
)

video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"

vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64)  # take the first 64 frames; a more elaborate sampling strategy can be used here
video = vr.get_frames_at(indices=frame_idx).data  # T x C x H x W
video = processor(video, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**video)

# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()`
encoder_outputs = outputs.last_hidden_state

# V-JEPA 2 predictor outputs
predictor_outputs = outputs.predictor_output.last_hidden_state
```
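The encoder returns a sequence of hidden vectors, one per spatio-temporal patch, assuming the usual V-JEPA tokenization of a clip into groups of frames ("tubelets") times square pixel patches. The snippet below is a rough, self-contained sketch of the shapes involved. It counts the tokens for a clip and mean-pools them into a single clip embedding. It assumes `patch_size=16`, `tubelet_size=2`, a 256×256 crop and a hidden size of 1024 for the ViT-L checkpoint. These are assumptions, so read the real values from `model.config` for your checkpoint. Mean pooling is only one simple, hypothetical choice. `VJEPA2ForVideoClassification` uses its own learned pooling head instead.

```python
import numpy as np

# Assumed configuration values (hypothetical constants for illustration);
# compare with the corresponding fields in `model.config` for your checkpoint.
PATCH_SIZE = 16
TUBELET_SIZE = 2
CROP_SIZE = 256
HIDDEN_SIZE = 1024


def num_video_tokens(num_frames, crop_size=CROP_SIZE, patch_size=PATCH_SIZE, tubelet_size=TUBELET_SIZE):
    """Count spatio-temporal patch tokens for one clip (assumed tokenization)."""
    temporal = num_frames // tubelet_size     # groups of `tubelet_size` frames
    spatial = (crop_size // patch_size) ** 2  # patches per square frame
    return temporal * spatial


def mean_pool(last_hidden_state):
    """Average over the token axis: (batch, tokens, hidden) -> (batch, hidden)."""
    return last_hidden_state.mean(axis=1)


tokens = num_video_tokens(64)
print(tokens)  # 32 temporal groups x 16 x 16 patches = 8192

# Random stand-in for `encoder_outputs`; only the shape matters here
fake_hidden = np.random.default_rng(0).standard_normal((1, tokens, HIDDEN_SIZE), dtype=np.float32)
clip_embedding = mean_pool(fake_hidden)
print(clip_embedding.shape)  # (1, 1024)
```

Under these assumptions, `encoder_outputs` from the 64-frame example above would have shape `(1, 8192, 1024)`. If your checkpoint's config differs, the numbers change accordingly.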

V-JEPA 2 can also be fine-tuned for video classification. The following snippet shows how to run a V-JEPA 2 model fine-tuned on Something-Something-V2 for video classification.

```python
import torch
import numpy as np

from torchcodec.decoders import VideoDecoder
from transformers import AutoVideoProcessor, AutoModelForVideoClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and video preprocessor
hf_repo = "facebook/vjepa2-vitl-fpc16-256-ssv2"

model = AutoModelForVideoClassification.from_pretrained(hf_repo).to(device)
processor = AutoVideoProcessor.from_pretrained(hf_repo)

# Load the video and sample frames according to the model's configuration
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, model.config.frames_per_clip, 8)  # every 8th frame up to `frames_per_clip`; a more elaborate sampling strategy can be used here
video = vr.get_frames_at(indices=frame_idx).data  # frames x channels x height x width

# Preprocess and run inference
inputs = processor(video, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

print("Top 5 predicted class names:")
top5_indices = logits.topk(5).indices[0]
top5_probs = torch.softmax(logits, dim=-1).topk(5).values[0]
for idx, prob in zip(top5_indices, top5_probs):
    text_label = model.config.id2label[idx.item()]
    print(f" - {text_label}: {prob:.2f}")
```
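Both snippets leave the frame sampling strategy open. One common alternative to consecutive or fixed-stride frames is to spread the sampled frames evenly across the whole video. The helpers below are a hypothetical NumPy sketch; `sample_frame_indices` and `sample_strided_indices` are not part of Transformers or torchcodec. The number of frames a given checkpoint expects should be taken from its config or model card rather than from this example.

```python
import numpy as np


def sample_frame_indices(total_frames, num_frames):
    """Hypothetical helper: pick `num_frames` evenly spaced indices covering the video.

    If the video has fewer frames than requested, some indices repeat.
    """
    if total_frames <= 0 or num_frames <= 0:
        raise ValueError("total_frames and num_frames must be positive")
    # linspace includes both endpoints; round each position to the nearest frame index
    return np.linspace(0, total_frames - 1, num=num_frames).round().astype(np.int64)


def sample_strided_indices(start, num_frames, stride, total_frames):
    """Hypothetical helper: `num_frames` indices with a fixed stride, clamped to the last frame."""
    idx = start + stride * np.arange(num_frames)
    return np.clip(idx, 0, total_frames - 1)


print(sample_frame_indices(total_frames=300, num_frames=16))
print(sample_strided_indices(start=0, num_frames=8, stride=4, total_frames=20))
```

The resulting indices can be passed to the decoder as before. For example, `vr.get_frames_at(indices=sample_frame_indices(len(vr), 16).tolist())` should work, assuming `len(vr)` reports the number of frames in the video.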

VJEPA2Config

autodoc VJEPA2Config

VJEPA2Model

autodoc VJEPA2Model - forward

VJEPA2ForVideoClassification

autodoc VJEPA2ForVideoClassification - forward

VJEPA2VideoProcessor

autodoc VJEPA2VideoProcessor