Add V-JEPA for video classification model (#38788)

* adding model and conversion scripts

* add imports to test vjepa conversion

* fix imports and make conversion work

* fix computation for short side

* replace attention with library attention function

* cleanup more attention classes

* remove config overrides

* add test cases, fix some of the failing ones

* fix the model outputs

* fix outputs of the model per review

* fix too big model test case

* fix styling __init__.py

* fix initialization test

* remove all asserts per review

* update sorting unsorting logic as per feedback

* remove is_video per review

* remove another is_video segment

* remove unwanted stuff

* small fixes

* add docstrings for the model

* revert adding vjepa2 config here

* update styling

* add config docstrings (wip)

* fix dpr issue

* removed test failing issues

* update styles

* merge predictor configs into main config

* remove processing code, add video processor

* remove permute which is not necessary now

* fix styles

* updated vjepa2 to be in video_processing_auto

* update comment for preprocessing

* test integration test and fix the outputs

* update test values, change test to look at repeated frames for a given image

* add a simple video processing test

* refactoring pixel_values_videos and upload ckpts to original

* fix torch_fx test cases

* remove unused config

* add all config docstrings

* add more integration tests

* add basic doc

* revert unwanted styling changes

* working make fixup

* Fix model_type in config

* Add ForVideoClassification model

* update attention implementation to fit new hf standards

* fix the preprocessing logic, ensure it matches the original model

* remove use_rope logic, cleanup

* fix docstrings

* Further cleanup, update doc

* Fix model prefix

* fix get_vision_features

* VJEPA2Embeddings style refactor

* nit, style comment

* change modules default values

* Only `str` activation in config

* GradientCheckpointingLayer

* fixup

* fix conversion script

* Remove return_dict

* remove None return typehint

* Refactor VJEPA2Layer, remove use_SiLU

* Fix fx tests

* dpr -> drop_path_rates

* move *ModelOutput on top

* format docs bit

* update docs

* update docs

* update doc example

* remove prune_heads from model

* remove unused config params

* refactor embed signature

* Add vjepa to docs

* Fix config docstring

* attention head

* update defaults

* Update docs/source/en/model_doc/vjepa2.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/vjepa2.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix import

* Min refactoring

* Update HUB_SOURCE and HUB_REPO in conversion script

* Add missing headers

* VJEPA -> V-JEPA in docs

* Add image to doc

* fix style

* fix init weights

* change checkpoint name in modeling tests

* Initial cls head setup

* remove rope attention from head (not needed)

* remove swigluffn - not needed

* Add siglip layer

* Replace with siglip layer

* Rename Siglip - VJEPA2

* remove unused modules

* remove siglip mlp

* nit

* remove MLP

* Refactor head cross attention

* refactor VJEPA2HeadCrossAttentionLayer

* nit renaming

* fixup

* remove commented code

* Add cls head params to config

* depth from config

* move pooler + classifier  to the model

* Update for cls model signature

* move layers, rename a bit

* fix docs

* update weights init

* remove typehint for init

* add to auto-mapping

* enable tests

* Add conversion script

* fixup

* add to docs

* fix docs

* nit

* refactor for mapping

* clean

* Add integration test

* Fixing multi gpu test

* update not-split-modules

* update video cls test tolerance

* Increase test_inference_image tolerance

* Update no-split modules for multi gpu

* Apply suggestions from code review

* fixing multi-gpu

* fix docstring

* Add cls snippet to docs

* Update checkpoint
Pavel Iakubovskii 2025-06-13 17:56:15 +01:00 committed by GitHub
parent 2ff964bcb4
commit 9bec2654ed
8 changed files with 698 additions and 52 deletions

View File

@ -38,7 +38,7 @@ This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yoni
## Usage example
The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class.
The snippet below shows how to load the V-JEPA 2 model for feature extraction using the `AutoModel` class.
```py
import torch
@ -68,6 +68,43 @@ encoder_outputs = outputs.last_hidden_state
predictor_outputs = outputs.predictor_output.last_hidden_state
```
V-JEPA 2 can also be fine-tuned for video classification. The snippet below shows how to use a checkpoint fine-tuned on Something-Something-V2 for video classification (an alternative frame-sampling strategy is sketched after the snippet).
```python
import torch
import numpy as np
from torchcodec.decoders import VideoDecoder
from transformers import AutoVideoProcessor, AutoModelForVideoClassification
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and video preprocessor
hf_repo = "facebook/vjepa2-vitl-fpc16-256-ssv2"
model = AutoModelForVideoClassification.from_pretrained(hf_repo).to(device)
processor = AutoVideoProcessor.from_pretrained(hf_repo)
# When loading a video, sample the number of frames expected by the model.
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, model.config.frames_per_clip, 8) # you can define a more complex sampling strategy
video = vr.get_frames_at(indices=frame_idx).data # frames x channels x height x width
# Preprocess and run inference
inputs = processor(video, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
print("Top 5 predicted class names:")
top5_indices = logits.topk(5).indices[0]
top5_probs = torch.softmax(logits, dim=-1).topk(5).values[0]
for idx, prob in zip(top5_indices, top5_probs):
text_label = model.config.id2label[idx.item()]
print(f" - {text_label}: {prob:.2f}")
```
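The stride-based indexing above is kept short for the example; for clips longer than `frames_per_clip`, frames can instead be sampled evenly across the whole video. A minimal sketch, assuming `torchcodec` exposes the frame count via `decoder.metadata.num_frames`:

```py
# spread exactly `frames_per_clip` frame indices evenly over the decoded video
total_frames = vr.metadata.num_frames
frame_idx = np.linspace(0, total_frames - 1, model.config.frames_per_clip, dtype=int)
video = vr.get_frames_at(indices=frame_idx).data  # frames x channels x height x width
```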
## VJEPA2Config
[[autodoc]] VJEPA2Config
@ -77,6 +114,11 @@ predictor_outputs = outputs.predictor_output.last_hidden_state
[[autodoc]] VJEPA2Model
- forward
## VJEPA2ForVideoClassification
[[autodoc]] VJEPA2ForVideoClassification
- forward
## VJEPA2VideoProcessor
[[autodoc]] VJEPA2VideoProcessor

View File

@ -150,6 +150,7 @@ LOSS_MAPPING = {
"ForQuestionAnswering": ForQuestionAnsweringLoss,
"ForSequenceClassification": ForSequenceClassificationLoss,
"ForImageClassification": ForSequenceClassificationLoss,
"ForVideoClassification": ForSequenceClassificationLoss,
"ForTokenClassification": ForTokenClassification,
"ForSegmentation": ForSegmentationLoss,
"ForObjectDetection": ForObjectDetectionLoss,

View File

@ -844,6 +844,7 @@ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("timesformer", "TimesformerForVideoClassification"),
("videomae", "VideoMAEForVideoClassification"),
("vivit", "VivitForVideoClassification"),
("vjepa2", "VJEPA2ForVideoClassification"),
]
)

View File

@ -60,6 +60,10 @@ class VJEPA2Config(PretrainedConfig):
`"relu"`, `"selu"` and `"gelu_new"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for attentions.
num_pooler_layers (`int`, *optional*, defaults to 3):
The number of self-attention layers in the pooler.
pred_hidden_size (`int`, *optional*, defaults to 384):
Dimensionality of the predictor layers
pred_num_attention_heads (`int`, *optional*, defaults to 12):
@ -107,6 +111,8 @@ class VJEPA2Config(PretrainedConfig):
attention_probs_dropout_prob=0.0,
hidden_act="gelu",
initializer_range=0.02,
attention_dropout=0.0,
num_pooler_layers=3,
# predictor params
pred_hidden_size=384,
pred_num_attention_heads=12,
@ -134,6 +140,8 @@ class VJEPA2Config(PretrainedConfig):
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.image_size = crop_size
self.attention_dropout = attention_dropout
self.num_pooler_layers = num_pooler_layers
# predictor params
self.pred_hidden_size = pred_hidden_size
self.pred_num_attention_heads = pred_num_attention_heads

View File

@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import numpy as np
import torch
from decord import VideoReader
from huggingface_hub import HfApi, hf_hub_download
from transformers import VJEPA2ForVideoClassification, VJEPA2VideoProcessor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_video():
path = hf_hub_download(
repo_id="nateraw/kinetics-mini",
filename="val/bowling/-WH-lxmGJVY_000005_000015.mp4",
repo_type="dataset",
)
video_reader = VideoReader(path)
return video_reader
CLASSIFIERS = {
# Something-Something-v2 dataset
"vjepa2-vitl-fpc16-256-ssv2": {
"base_model": "facebook/vjepa2-vitl-fpc64-256",
"checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitl-16x2x3.pt",
"num_labels": 174,
"frames_per_clip": 16,
"dataset": "something-something-v2",
"result": (145, 0.30867, "Stuffing [something] into [something]"),
},
"vjepa2-vitg-fpc64-384-ssv2": {
"base_model": "facebook/vjepa2-vitg-fpc64-384",
"checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt",
"frames_per_clip": 64,
"num_labels": 174,
"dataset": "something-something-v2",
"result": (112, 0.26408, "Putting [something] onto [something]"),
},
# Diving48 dataset
"vjepa2-vitl-fpc32-256-diving48": {
"base_model": "facebook/vjepa2-vitl-fpc64-256",
"checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitl-256.pt",
"num_labels": 48,
"frames_per_clip": 32,
"dataset": "diving48",
"result": (35, 0.32875, "['Inward', '35som', 'NoTwis', 'TUCK']"),
},
"vjepa2-vitg-fpc32-384-diving48": {
"base_model": "facebook/vjepa2-vitg-fpc64-384",
"checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitg-384-32x4x3.pt",
"frames_per_clip": 32,
"num_labels": 48,
"dataset": "diving48",
"result": (22, 0.35351, "['Forward', '25som', '2Twis', 'PIKE']"),
},
}
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"module.pooler.query_tokens": r"pooler.query_tokens",
r"module.pooler.cross_attention_block.norm(\d+).": r"pooler.cross_attention_layer.layer_norm\1.",
r"module.pooler.cross_attention_block.xattn.(q|k|v).": r"pooler.cross_attention_layer.cross_attn.\1_proj.",
r"module.pooler.cross_attention_block.mlp.fc(\d+).": r"pooler.cross_attention_layer.mlp.fc\1.",
r"module.pooler.blocks.(\d+).norm(\d+).": r"pooler.self_attention_layers.\1.layer_norm\2.",
r"module.pooler.blocks.(\d+).attn.(q|k|v).": r"pooler.self_attention_layers.\1.self_attn.\2_proj.",
r"module.pooler.blocks.(\d+).attn.proj.": r"pooler.self_attention_layers.\1.self_attn.out_proj.",
r"module.pooler.blocks.(\d+).mlp.fc(\d+).": r"pooler.self_attention_layers.\1.mlp.fc\2.",
r"module.linear.": r"classifier.",
}
# fmt: on
def get_id2label_mapping(dataset_name: str) -> dict[int, str]:
path = hf_hub_download(
repo_id="huggingface/label-files",
filename=f"{dataset_name}-id2label.json",
repo_type="dataset",
)
with open(path, "r") as f:
id2label = json.load(f)
id2label = {int(k): v for k, v in id2label.items()}
return id2label
def split_qkv(state_dict):
state_dict = state_dict.copy()
keys = list(state_dict.keys())
for key in keys:
if ".qkv." in key:
tensor = state_dict.pop(key)
q, k, v = torch.chunk(tensor, 3, dim=0)
state_dict[key.replace(".qkv.", ".q.")] = q
state_dict[key.replace(".qkv.", ".k.")] = k
state_dict[key.replace(".qkv.", ".v.")] = v
elif ".kv." in key:
tensor = state_dict.pop(key)
k, v = torch.chunk(tensor, 2, dim=0)
state_dict[key.replace(".kv.", ".k.")] = k
state_dict[key.replace(".kv.", ".v.")] = v
return state_dict
def convert_old_keys_to_new_keys(state_dict):
"""
This function should be applied only once, on the concatenated keys to efficiently rename using
the key mappings.
"""
output_dict = {}
old_text = "\n".join(state_dict)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
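# For example (illustrative only): after `split_qkv`, the original key
# "module.pooler.blocks.0.attn.q.weight" is renamed to
# "pooler.self_attention_layers.0.self_attn.q_proj.weight", and
# "module.linear.weight" becomes "classifier.weight".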
def main(args: argparse.Namespace):
model_params = CLASSIFIERS[args.model_name]
id2label = get_id2label_mapping(model_params["dataset"])
if not len(id2label) == model_params["num_labels"]:
raise ValueError(
f"Number of labels in id2label mapping ({len(id2label)}) does not "
f"match number of labels in model ({model_params['num_labels']})"
)
model = VJEPA2ForVideoClassification.from_pretrained(
model_params["base_model"],
num_labels=model_params["num_labels"],
id2label=id2label,
frames_per_clip=model_params["frames_per_clip"],
)
processor = VJEPA2VideoProcessor.from_pretrained(model_params["base_model"])
# load and convert classifier checkpoint
checkpoint = torch.hub.load_state_dict_from_url(model_params["checkpoint"])
state_dict = checkpoint["classifiers"][0]
state_dict_qkv_split = split_qkv(state_dict)
key_mapping = convert_old_keys_to_new_keys(state_dict_qkv_split.keys())
converted_state_dict2 = {key_mapping[k]: v for k, v in state_dict_qkv_split.items()}
result = model.load_state_dict(converted_state_dict2, strict=False)
if result.unexpected_keys:
raise ValueError(f"Error loading state dict: {result.unexpected_keys}")
if not args.skip_verification:
# get inputs
video_reader = get_video()
frame_indexes = np.arange(0, 128, 128 / model_params["frames_per_clip"])
video = video_reader.get_batch(frame_indexes).asnumpy()
inputs = processor(video, return_tensors="pt").to(device)
# run model
model.to(device).eval()
with torch.no_grad():
outputs = model(**inputs)
# compare results
probs = torch.softmax(outputs.logits, dim=-1)
top_prob, top_idx = probs.topk(1)
top_prob, top_idx = top_prob.item(), top_idx.item()
label = id2label[top_idx]
expected_id, expected_prob, expected_label = model_params["result"]
if not top_idx == expected_id:
raise ValueError(f"Expected id {expected_id} but got {top_idx}")
if not label == expected_label:
raise ValueError(f"Expected label {expected_label} but got {label}")
if not np.isclose(top_prob, expected_prob, atol=1e-3):
raise ValueError(f"Expected prob {expected_prob} but got {top_prob}")
print("Verification passed")
output_dir = os.path.join(args.base_dir, args.model_name)
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
if args.push_to_hub:
api = HfApi()
repo_id = f"{args.repo_org}/{args.model_name}"
if not api.repo_exists(repo_id):
api.create_repo(repo_id, repo_type="model")
api.upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--base_dir", type=str, default="converted_models/")
parser.add_argument("--repo_org", type=str, default="qubvel-hf")
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--skip_verification", action="store_true")
args = parser.parse_args()
main(args)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
import torch
@ -19,7 +20,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, dataclass
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_vjepa2 import VJEPA2Config
@ -536,17 +537,21 @@ class VJEPA2Encoder(nn.Module):
)
def apply_masks(x, masks) -> torch.Tensor:
def apply_masks(tensor: torch.Tensor, masks: List[torch.Tensor]) -> torch.Tensor:
"""
:param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
:param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
Args:
tensor (`torch.Tensor`):
Tensor of shape [batch_size, num_patches, feature_dim]
masks (`List[torch.Tensor]`):
List of tensors of shape [batch_size, num_patches] containing indices of patches to keep
"""
all_x = []
for m in masks:
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
all_x += [torch.gather(x, dim=1, index=mask_keep)]
all_masked_tensors = []
for mask in masks:
mask = mask.to(tensor.device)
mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1))
all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)]
return torch.cat(all_x, dim=0)
return torch.cat(all_masked_tensors, dim=0)
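# Illustrative example: for `tensor` of shape [1, 4, 2] and masks = [torch.tensor([[0, 2]])],
# the patches at indices 0 and 2 are kept, giving a result of shape [1, 2, 2];
# results for multiple masks are concatenated along the batch dimension.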
class VJEPA2PredictorEmbeddings(nn.Module):
@ -649,13 +654,18 @@ class VJEPA2Predictor(nn.Module):
self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)
def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
# gather position masks
argsort = argsort.to(position_masks.device)
position_masks = torch.gather(position_masks, dim=1, index=argsort)
hidden_states = torch.gather(
hidden_states,
dim=1,
index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
)
# gather hidden states
argsort = argsort.to(hidden_states.device)
hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort)
# gather head mask
if head_mask is not None and head_mask[0] is not None:
argsort = argsort.to(head_mask.device)
head_mask = head_mask.permute(1, 0, 2, 3, 4)
argsort_4d = (
argsort.unsqueeze(1)
@ -673,15 +683,14 @@ class VJEPA2Predictor(nn.Module):
)
head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
head_mask = head_mask.permute(1, 0, 2, 3, 4)
return hidden_states, position_masks, head_mask
def unsort_tokens(self, hidden_states, argsort):
argsort = argsort.to(hidden_states.device)
reverse_argsort = torch.argsort(argsort, dim=1)
hidden_states = torch.gather(
hidden_states,
dim=1,
index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
)
reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort)
return hidden_states
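# Note: `unsort_tokens` inverts the token permutation applied in `sort_tokens`,
# since `torch.argsort(argsort, dim=1)` yields the inverse permutation; gathering
# with it restores the original token order of `hidden_states`.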
@can_return_tuple
@ -735,49 +744,304 @@ class VJEPA2Predictor(nn.Module):
)
class VJEPA2PoolerSelfAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: VJEPA2Config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
class VJEPA2PoolerCrossAttention(nn.Module):
"""It's different from other cross-attention layers, doesn't have output projection layer (o_proj)"""
# in case of modular refactoring - o_proj can be replaces with nn.Identity()
def __init__(self, config: VJEPA2Config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
queries: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_seq_length, embed_dim = queries.shape
kv_seq_length = keys.shape[1]
queries = self.q_proj(queries)
keys = self.k_proj(keys)
values = self.v_proj(values)
queries = queries.view(batch_size, q_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous()
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
# Modified from SiglipEncoderLayer, but we have to propagate proper hidden_size to VJEPA2MLP
class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
def __init__(self, config: VJEPA2Config):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.self_attn = VJEPA2PoolerSelfAttention(config)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, ...]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
def __init__(self, config: VJEPA2Config):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cross_attn = VJEPA2PoolerCrossAttention(config)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)
def forward(
self,
queries: torch.Tensor,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
# Apply cross-attention
residual = queries
hidden_state = self.layer_norm1(hidden_state)
hidden_state, *attn_weights = self.cross_attn(
queries,
hidden_state,
hidden_state,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_state = residual + hidden_state
# Apply MLP
residual = hidden_state
hidden_state = self.layer_norm2(hidden_state)
hidden_state = self.mlp(hidden_state)
hidden_state = residual + hidden_state
outputs = (hidden_state,)
if output_attentions:
outputs += tuple(attn_weights)
return outputs
class VJEPA2AttentivePooler(nn.Module):
"""Attentive Pooler"""
def __init__(self, config: VJEPA2Config):
super().__init__()
self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
self.self_attention_layers = nn.ModuleList(
[VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
for layer in self.self_attention_layers:
hidden_state = layer(hidden_state, attention_mask=None)[0]
queries = self.query_tokens.repeat(hidden_state.shape[0], 1, 1)
hidden_state = self.cross_attention_layer(queries, hidden_state)[0]
return hidden_state.squeeze(1)
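# Shape flow: `hidden_state` [batch_size, num_tokens, hidden_size] is refined by the
# self-attention layers, then a single learned query per sample cross-attends over
# the tokens, so the returned tensor has shape [batch_size, hidden_size].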
@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
config_class = VJEPA2Config
base_model_prefix = "vjepa2"
main_input_name = "pixel_values_videos"
supports_gradient_checkpointing = True
_no_split_modules = ["VJEPA2Layer"]
_no_split_modules = [
"VJEPA2Layer",
"VJEPA2PoolerSelfAttentionLayer",
"VJEPA2PoolerCrossAttentionLayer",
"VJEPA2PredictorEmbeddings",
]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(
self,
module: Union[
nn.Linear,
nn.Conv2d,
nn.LayerNorm,
VJEPA2Embeddings,
VJEPA2PredictorEmbeddings,
],
):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.weight.dtype)
init_std = self.config.initializer_range
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
def trunc_normal_f32_(weight, std):
data_float_32 = weight.data.to(torch.float32)
data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std)
weight.data = data_init.to(weight.dtype)
if isinstance(module, VJEPA2AttentivePooler):
trunc_normal_f32_(module.query_tokens, std=init_std)
for i, layer in enumerate(module.self_attention_layers, 1):
std = init_std / (i**0.5)
trunc_normal_f32_(layer.self_attn.out_proj.weight, std=std)
trunc_normal_f32_(layer.mlp.fc2.weight, std=std)
std = init_std / (len(module.self_attention_layers) + 1) ** 0.5
trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=std)
elif isinstance(module, VJEPA2PredictorEmbeddings):
if module.zero_init_mask_tokens:
module.mask_tokens.data.zero_()
else:
trunc_normal_f32_(module.mask_tokens, std=init_std)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
trunc_normal_f32_(module.weight, std=init_std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, VJEPA2PredictorEmbeddings):
if not module.zero_init_mask_tokens:
module.mask_token = nn.init.trunc_normal_(
module.mask_token.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.mask_token.dtype)
else:
module.mask_tokens.data.zero_()
def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
@ -900,4 +1164,92 @@ class VJEPA2Model(VJEPA2PreTrainedModel):
return encoder_output.last_hidden_state
__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"]
@auto_docstring(
custom_intro="""
V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
"""
)
class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
def __init__(self, config: VJEPA2Config):
super().__init__(config)
self.num_labels = config.num_labels
self.vjepa2 = VJEPA2Model(config)
# Classifier head
self.pooler = VJEPA2AttentivePooler(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values_videos: torch.Tensor,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Union[Tuple, ImageClassifierOutput]:
r"""
pixel_values_videos (`torch.Tensor` with shape `[batch size x num_frames x num_channels x height x width]`):
The input video pixels, as processed by VJEPA2VideoProcessor.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the video classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Examples:
```python
>>> import torch
>>> import numpy as np
>>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification
>>> device = "cuda"
>>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
>>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)
>>> video = np.ones((64, 256, 256, 3)) # 64 frames, 256x256 RGB
>>> inputs = video_processor(video, return_tensors="pt").to(device)
>>> # For inference
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> logits = outputs.logits
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
>>> # For training
>>> labels = torch.ones(1, dtype=torch.long, device=device)
>>> loss = model(**inputs, labels=labels).loss
```"""
outputs = self.vjepa2(
pixel_values_videos=pixel_values_videos,
skip_predictor=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = outputs.last_hidden_state
pooler_output = self.pooler(last_hidden_state)
logits = self.classifier(pooler_output)
loss = None
if labels is not None:
loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel", "VJEPA2ForVideoClassification"]

View File

@ -56,6 +56,7 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
@ -194,6 +195,7 @@ _SPECIAL_SUPPORTED_MODELS = [
"TrOCRDecoder",
"PeftModelForCausalLM",
"PeftModelForSeq2SeqLM",
"VJEPA2ForVideoClassification",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering,
]
@ -904,6 +906,7 @@ class HFTracer(Tracer):
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
*get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES),
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
]:

View File

@ -40,7 +40,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import VJEPA2Model
from transformers import VJEPA2ForVideoClassification, VJEPA2Model
if is_vision_available():
@ -153,7 +153,7 @@ class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_torch_exportable = True
all_model_classes = (VJEPA2Model,) if is_torch_available() else ()
all_model_classes = (VJEPA2Model, VJEPA2ForVideoClassification) if is_torch_available() else ()
fx_compatible = True
@ -267,7 +267,7 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase):
[[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
device=torch_device,
)
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=8e-2, atol=8e-2)
@slow
def test_inference_video(self):
@ -343,3 +343,22 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase):
# verify the last hidden states
expected_shape = torch.Size((1, num_masks, 1024))
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
@slow
def test_video_classification(self):
checkpoint = "facebook/vjepa2-vitl-fpc16-256-ssv2"
model = VJEPA2ForVideoClassification.from_pretrained(checkpoint).to(torch_device)
video_processor = AutoVideoProcessor.from_pretrained(checkpoint)
sample_video = np.ones((16, 3, 256, 256))
inputs = video_processor(sample_video, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
self.assertEqual(outputs.logits.shape, (1, 174))
expected_logits = torch.tensor([0.8814, -0.1195, -0.6389], device=torch_device)
resulted_logits = outputs.logits[0, 100:103]
torch.testing.assert_close(resulted_logits, expected_logits, rtol=1e-2, atol=1e-2)