
V-JEPA 2
V-JEPA 2 is a self-supervised approach to training video encoders developed by FAIR, Meta. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.

You can find all original V-JEPA 2 checkpoints under the V-JEPA 2 collection.
This model was contributed by koustuvs, yonigozlan and qubvel. The original code can be found here.
Usage example
The snippet below shows how to load the V-JEPA 2 model for feature extraction using the AutoModel class.
import torch
from torchcodec.decoders import VideoDecoder
import numpy as np
from transformers import AutoVideoProcessor, AutoModel
processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
model = AutoModel.from_pretrained(
    "facebook/vjepa2-vitl-fpc64-256",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64) # take the first 64 frames; a more complex sampling strategy can be defined here
video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W
video = processor(video, return_tensors="pt").to(model.device)
outputs = model(**video)
# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()`
encoder_outputs = outputs.last_hidden_state
# V-JEPA 2 predictor outputs
predictor_outputs = outputs.predictor_output.last_hidden_state
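The frame selection above simply takes the first 64 frames. As a minimal sketch of one alternative, the helper below spreads a fixed number of indices evenly across the whole video. The name `sample_frame_indices` and its parameters are hypothetical and are not part of transformers or torchcodec.

```python
def sample_frame_indices(num_video_frames: int, num_samples: int) -> list[int]:
    """Return `num_samples` frame indices spread evenly over a video.

    Hypothetical helper for illustration; not a library function.
    """
    if num_video_frames <= 0 or num_samples <= 0:
        raise ValueError("num_video_frames and num_samples must be positive")
    if num_samples == 1:
        return [0]
    last = num_video_frames - 1
    # Integer arithmetic keeps the result deterministic: the first index is 0,
    # the last index is the final frame, and the rest are evenly spaced (floored).
    # If the video is shorter than num_samples, some indices repeat.
    return [(i * last) // (num_samples - 1) for i in range(num_samples)]


indices = sample_frame_indices(num_video_frames=300, num_samples=64)
print(indices[:4], indices[-1])
```

Assuming the decoder reports the video length (for example via `len(vr)`), the resulting list could be passed as `indices=` to `vr.get_frames_at` in place of `np.arange(0, 64)`.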
V-JEPA 2 can also be finetuned for video classification. The following snippet shows how to use a model finetuned on Something-Something-V2 for video classification.
import torch
import numpy as np
from torchcodec.decoders import VideoDecoder
from transformers import AutoVideoProcessor, AutoModelForVideoClassification
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and video preprocessor
hf_repo = "facebook/vjepa2-vitl-fpc16-256-ssv2"
model = AutoModelForVideoClassification.from_pretrained(hf_repo).to(device)
processor = AutoVideoProcessor.from_pretrained(hf_repo)
# To load a video, sample the number of frames according to the model.
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, model.config.frames_per_clip, 8) # you can define more complex sampling strategy
video = vr.get_frames_at(indices=frame_idx).data # frames x channels x height x width
# Preprocess and run inference
inputs = processor(video, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
print("Top 5 predicted class names:")
top5_indices = logits.topk(5).indices[0]
top5_probs = torch.softmax(logits, dim=-1).topk(5).values[0]
for idx, prob in zip(top5_indices, top5_probs):
    text_label = model.config.id2label[idx.item()]
    print(f" - {text_label}: {prob:.2f}")
VJEPA2Config
autodoc VJEPA2Config
VJEPA2Model
autodoc VJEPA2Model - forward
VJEPA2ForVideoClassification
autodoc VJEPA2ForVideoClassification - forward
VJEPA2VideoProcessor
autodoc VJEPA2VideoProcessor