From 9bec2654ed5b4ac43e880dc7e3cb2c18aeae70a9 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Fri, 13 Jun 2025 17:56:15 +0100 Subject: [PATCH] Add V-JEPA for video classification model (#38788) * adding model and conversion scripts * add imports to test vjepa conversion * fix imports and make conversion work * fix computation for short side * replace attention with library attention function * cleanup more attention classes * remove config overrides * add test cases, fix some of the failing ones * fix the model outputs * fix outputs of the model per review * fix too big model test case * fix styling __init__.py * fix initialization test * remove all asserts per review * update sorting unsorting logic as per feedback * remove is_video per review * remove another is_video segment * remove unwanted stuff * small fixes * add docstrings for the model * revert adding vjepa2 config here * update styling * add config docstrings (wip) * fix dpr issue * removed test failing issues * update styles * merge predictor configs into main config * remove processing code, add video processor * remove permute which is not necessary now * fix styles * updated vjepa2 to be in video_processing_auto * update comment for preprocessing * test integration test and fix the outputs * update test values, change test to look at repeated frames for a given image * add a simple video processing test * refactoring pixel_values_videos and upload ckpts to original * fix torch_fx test cases * remove unused config * add all config docstrings * add more integration tests * add basic doc * revert unwanted styling changes * working make fixup * Fix model_type in config * Add ForVideoClassification model * update attention implementation to fit new hf standards * fix the preprocessing logic, ensure it matches the original model * remove use_rope logic, cleanup * fix docstrings * Further cleanup, update doc * Fix model prefix * fix get_vision_features * VJEPA2Embeddings style refactor * nit, style comment * change modules default values * Only `str` activation in config * GradientCheckpointingLayer * fixup * fix conversion script * Remove return_dict * remove None return typehint * Refactor VJEPA2Layer, remove use_SiLU * Fix fx tests * dpr -> drop_path_rates * move *ModelOutput on top * format docs bit * update docs * update docs * update doc example * remove prune_heads from model * remove unused config params * refactor embed signature * Add vjepa to docs * Fix config docstring * attention head * update defaults * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca * Fix import * Min refactoring * Update HUB_SOURCE and HUB_REPO in conversion script * Add missing headers * VJEPA -> V-JEPA in docs * Add image to doc * fix style * fix init weights * change checkpoint name in modeling tests * Initial cls head setup * remove rop attention from head (not needed) * remove swigluffn - not needed * Add siglip layer * Replace with siglip layer * Rename Siglip - VJEPA2 * remove unused modules * remove siglip mlp * nit * remove MLP * Refactor head cross attention * refactor VJEPA2HeadCrossAttentionLayer * nit renaming * fixup * remove commented code * Add cls head params to config * depth from config * move pooler + classifier to the model * Update for cls model signature * move layers, rename a bit * fix docs * update weights init * remove typehint for init * add to auto-mapping * enable tests * Add conversion script * fixup * add to docs * 
fix docs * nit * refactor for mapping * clean * Add integration test * Fixing multi gpu test * update not-split-modules * update video cls test tolerance * Increase test_inference_image tolerance * Update no-split modules for multi gpu * Apply suggestions from code review * fixing multi-gpu * fix docstring * Add cls snippet to docs * Update checkpoint --- docs/source/en/model_doc/vjepa2.md | 44 +- src/transformers/loss/loss_utils.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/vjepa2/configuration_vjepa2.py | 8 + .../vjepa2/convert_vjepa2_classifier_to_hf.py | 220 +++++++++ .../models/vjepa2/modeling_vjepa2.py | 448 ++++++++++++++++-- src/transformers/utils/fx.py | 3 + tests/models/vjepa2/test_modeling_vjepa2.py | 25 +- 8 files changed, 698 insertions(+), 52 deletions(-) create mode 100644 src/transformers/models/vjepa2/convert_vjepa2_classifier_to_hf.py diff --git a/docs/source/en/model_doc/vjepa2.md b/docs/source/en/model_doc/vjepa2.md index 5ad02ae274b..b16875339ed 100644 --- a/docs/source/en/model_doc/vjepa2.md +++ b/docs/source/en/model_doc/vjepa2.md @@ -38,7 +38,7 @@ This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yoni ## Usage example -The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class. +The snippet below shows how to load the V-JEPA 2 model for feature extraction using the `AutoModel` class. ```py import torch @@ -68,6 +68,43 @@ encoder_outputs = outputs.last_hidden_state predictor_outputs = outputs.predictor_output.last_hidden_state ``` +V-JEPA 2 can also be finetuned for video classification. In the following snippet, we show how to use a model finetuned on Something-Something-V2 for video classification. + +```python +import torch +import numpy as np + +from torchcodec.decoders import VideoDecoder +from transformers import AutoVideoProcessor, AutoModelForVideoClassification + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# Load model and video preprocessor +hf_repo = "facebook/vjepa2-vitl-fpc16-256-ssv2" + +model = AutoModelForVideoClassification.from_pretrained(hf_repo).to(device) +processor = AutoVideoProcessor.from_pretrained(hf_repo) + +# To load a video, sample the number of frames according to the model.
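+# This checkpoint (fpc16) expects model.config.frames_per_clip == 16 frames per clip.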
+video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4" +vr = VideoDecoder(video_url) +frame_idx = np.arange(0, model.config.frames_per_clip, 8) # you can define more complex sampling strategy +video = vr.get_frames_at(indices=frame_idx).data # frames x channels x height x width + +# Preprocess and run inference +inputs = processor(video, return_tensors="pt").to(model.device) +with torch.no_grad(): + outputs = model(**inputs) +logits = outputs.logits + +print("Top 5 predicted class names:") +top5_indices = logits.topk(5).indices[0] +top5_probs = torch.softmax(logits, dim=-1).topk(5).values[0] +for idx, prob in zip(top5_indices, top5_probs): + text_label = model.config.id2label[idx.item()] + print(f" - {text_label}: {prob:.2f}") +``` + ## VJEPA2Config [[autodoc]] VJEPA2Config @@ -77,6 +114,11 @@ predictor_outputs = outputs.predictor_output.last_hidden_state [[autodoc]] VJEPA2Model - forward +## VJEPA2ForVideoClassification + +[[autodoc]] VJEPA2ForVideoClassification + - forward + ## VJEPA2VideoProcessor [[autodoc]] VJEPA2VideoProcessor diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 764d28d6f34..75c4cbf3451 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -150,6 +150,7 @@ LOSS_MAPPING = { "ForQuestionAnswering": ForQuestionAnsweringLoss, "ForSequenceClassification": ForSequenceClassificationLoss, "ForImageClassification": ForSequenceClassificationLoss, + "ForVideoClassification": ForSequenceClassificationLoss, "ForTokenClassification": ForTokenClassification, "ForSegmentation": ForSegmentationLoss, "ForObjectDetection": ForObjectDetectionLoss, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index bcd56483c11..1b776b66ce3 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -844,6 +844,7 @@ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("timesformer", "TimesformerForVideoClassification"), ("videomae", "VideoMAEForVideoClassification"), ("vivit", "VivitForVideoClassification"), + ("vjepa2", "VJEPA2ForVideoClassification"), ] ) diff --git a/src/transformers/models/vjepa2/configuration_vjepa2.py b/src/transformers/models/vjepa2/configuration_vjepa2.py index 4571b886021..1fd19c4d078 100644 --- a/src/transformers/models/vjepa2/configuration_vjepa2.py +++ b/src/transformers/models/vjepa2/configuration_vjepa2.py @@ -60,6 +60,10 @@ class VJEPA2Config(PretrainedConfig): `"relu"`, `"selu"` and `"gelu_new"` are supported. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for attentions. + num_pooler_layers (`int`, *optional*, defaults to 3): + The number of self-attention layers in the pooler. 
pred_hidden_size (`int`, *optional*, defaults to 384): Dimensionality of the predictor layers pred_num_attention_heads (`int`, *optional*, defaults to 12): @@ -107,6 +111,8 @@ class VJEPA2Config(PretrainedConfig): attention_probs_dropout_prob=0.0, hidden_act="gelu", initializer_range=0.02, + attention_dropout=0.0, + num_pooler_layers=3, # predictor params pred_hidden_size=384, pred_num_attention_heads=12, @@ -134,6 +140,8 @@ class VJEPA2Config(PretrainedConfig): self.hidden_act = hidden_act self.initializer_range = initializer_range self.image_size = crop_size + self.attention_dropout = attention_dropout + self.num_pooler_layers = num_pooler_layers # predictor params self.pred_hidden_size = pred_hidden_size self.pred_num_attention_heads = pred_num_attention_heads diff --git a/src/transformers/models/vjepa2/convert_vjepa2_classifier_to_hf.py b/src/transformers/models/vjepa2/convert_vjepa2_classifier_to_hf.py new file mode 100644 index 00000000000..4e3512f5f9f --- /dev/null +++ b/src/transformers/models/vjepa2/convert_vjepa2_classifier_to_hf.py @@ -0,0 +1,220 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +import re + +import numpy as np +import torch +from decord import VideoReader +from huggingface_hub import HfApi, hf_hub_download + +from transformers import VJEPA2ForVideoClassification, VJEPA2VideoProcessor + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def get_video(): + path = hf_hub_download( + repo_id="nateraw/kinetics-mini", + filename="val/bowling/-WH-lxmGJVY_000005_000015.mp4", + repo_type="dataset", + ) + video_reader = VideoReader(path) + return video_reader + + +CLASSIFIERS = { + # Something-Something-v2 dataset + "vjepa2-vitl-fpc16-256-ssv2": { + "base_model": "facebook/vjepa2-vitl-fpc64-256", + "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitl-16x2x3.pt", + "num_labels": 174, + "frames_per_clip": 16, + "dataset": "something-something-v2", + "result": (145, 0.30867, "Stuffing [something] into [something]"), + }, + "vjepa2-vitg-fpc64-384-ssv2": { + "base_model": "facebook/vjepa2-vitg-fpc64-384", + "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt", + "frames_per_clip": 64, + "num_labels": 174, + "dataset": "something-something-v2", + "result": (112, 0.26408, "Putting [something] onto [something]"), + }, + # Diving48 dataset + "vjepa2-vitl-fpc32-256-diving48": { + "base_model": "facebook/vjepa2-vitl-fpc64-256", + "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitl-256.pt", + "num_labels": 48, + "frames_per_clip": 32, + "dataset": "diving48", + "result": (35, 0.32875, "['Inward', '35som', 'NoTwis', 'TUCK']"), + }, + "vjepa2-vitg-fpc32-384-diving48": { + "base_model": "facebook/vjepa2-vitg-fpc64-384", + "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitg-384-32x4x3.pt", + "frames_per_clip": 32, + "num_labels": 48, + "dataset": 
"diving48", + "result": (22, 0.35351, "['Forward', '25som', '2Twis', 'PIKE']"), + }, +} + +# fmt: off +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"module.pooler.query_tokens": r"pooler.query_tokens", + r"module.pooler.cross_attention_block.norm(\d+).": r"pooler.cross_attention_layer.layer_norm\1.", + r"module.pooler.cross_attention_block.xattn.(q|k|v).": r"pooler.cross_attention_layer.cross_attn.\1_proj.", + r"module.pooler.cross_attention_block.mlp.fc(\d+).": r"pooler.cross_attention_layer.mlp.fc\1.", + r"module.pooler.blocks.(\d+).norm(\d+).": r"pooler.self_attention_layers.\1.layer_norm\2.", + r"module.pooler.blocks.(\d+).attn.(q|k|v).": r"pooler.self_attention_layers.\1.self_attn.\2_proj.", + r"module.pooler.blocks.(\d+).attn.proj.": r"pooler.self_attention_layers.\1.self_attn.out_proj.", + r"module.pooler.blocks.(\d+).mlp.fc(\d+).": r"pooler.self_attention_layers.\1.mlp.fc\2.", + r"module.linear.": r"classifier.", +} +# fmt: on + + +def get_id2label_mapping(dataset_name: str) -> dict[int, str]: + path = hf_hub_download( + repo_id="huggingface/label-files", + filename=f"{dataset_name}-id2label.json", + repo_type="dataset", + ) + with open(path, "r") as f: + id2label = json.load(f) + id2label = {int(k): v for k, v in id2label.items()} + return id2label + + +def split_qkv(state_dict): + state_dict = state_dict.copy() + keys = list(state_dict.keys()) + for key in keys: + if ".qkv." in key: + tensor = state_dict.pop(key) + q, k, v = torch.chunk(tensor, 3, dim=0) + state_dict[key.replace(".qkv.", ".q.")] = q + state_dict[key.replace(".qkv.", ".k.")] = k + state_dict[key.replace(".qkv.", ".v.")] = v + elif ".kv." in key: + tensor = state_dict.pop(key) + k, v = torch.chunk(tensor, 2, dim=0) + state_dict[key.replace(".kv.", ".k.")] = k + state_dict[key.replace(".kv.", ".v.")] = v + + return state_dict + + +def convert_old_keys_to_new_keys(state_dict): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. 
+ """ + output_dict = {} + old_text = "\n".join(state_dict) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +def main(args: argparse.Namespace): + model_params = CLASSIFIERS[args.model_name] + id2label = get_id2label_mapping(model_params["dataset"]) + + if not len(id2label) == model_params["num_labels"]: + raise ValueError( + f"Number of labels in id2label mapping ({len(id2label)}) does not " + f"match number of labels in model ({model_params['num_labels']})" + ) + + model = VJEPA2ForVideoClassification.from_pretrained( + model_params["base_model"], + num_labels=model_params["num_labels"], + id2label=id2label, + frames_per_clip=model_params["frames_per_clip"], + ) + processor = VJEPA2VideoProcessor.from_pretrained(model_params["base_model"]) + + # load and convert classifier checkpoint + checkpoint = torch.hub.load_state_dict_from_url(model_params["checkpoint"]) + state_dict = checkpoint["classifiers"][0] + + state_dict_qkv_split = split_qkv(state_dict) + key_mapping = convert_old_keys_to_new_keys(state_dict_qkv_split.keys()) + converted_state_dict2 = {key_mapping[k]: v for k, v in state_dict_qkv_split.items()} + + result = model.load_state_dict(converted_state_dict2, strict=False) + if result.unexpected_keys: + raise ValueError(f"Error loading state dict: {result.unexpected_keys}") + + if not args.skip_verification: + # get inputs + video_reader = get_video() + frame_indexes = np.arange(0, 128, 128 / model_params["frames_per_clip"]) + video = video_reader.get_batch(frame_indexes).asnumpy() + inputs = processor(video, return_tensors="pt").to(device) + + # run model + model.to(device).eval() + with torch.no_grad(): + outputs = model(**inputs) + + # compare results + probs = torch.softmax(outputs.logits, dim=-1) + top_prob, top_idx = probs.topk(1) + top_prob, top_idx = top_prob.item(), top_idx.item() + label = id2label[top_idx] + expected_id, expected_prob, expected_label = model_params["result"] + + if not top_idx == expected_id: + raise ValueError(f"Expected id {expected_id} but got {top_idx}") + if not label == expected_label: + raise ValueError(f"Expected label {expected_label} but got {label}") + if not np.isclose(top_prob, expected_prob, atol=1e-3): + raise ValueError(f"Expected prob {expected_prob} but got {top_prob}") + print("Verification passed") + + output_dir = os.path.join(args.base_dir, args.model_name) + model.save_pretrained(output_dir) + processor.save_pretrained(output_dir) + + if args.push_to_hub: + api = HfApi() + repo_id = f"{args.repo_org}/{args.model_name}" + if not api.repo_exists(repo_id): + api.create_repo(repo_id, repo_type="model") + api.upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--base_dir", type=str, default="converted_models/") + parser.add_argument("--repo_org", type=str, default="qubvel-hf") + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--skip_verification", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index 
7a3a95b1298..8e13434307c 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -19,7 +20,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, dataclass +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging from .configuration_vjepa2 import VJEPA2Config @@ -536,17 +537,21 @@ class VJEPA2Encoder(nn.Module): ) -def apply_masks(x, masks) -> torch.Tensor: +def apply_masks(tensor: torch.Tensor, masks: List[torch.Tensor]) -> torch.Tensor: """ - :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] - :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep + Args: + tensor (`torch.Tensor`): + Tensor of shape [batch_size, num_patches, feature_dim] + masks (`List[torch.Tensor]`): + List of tensors of shape [batch_size, num_patches] containing indices of patches to keep """ - all_x = [] - for m in masks: - mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) - all_x += [torch.gather(x, dim=1, index=mask_keep)] + all_masked_tensors = [] + for mask in masks: + mask = mask.to(tensor.device) + mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1)) + all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)] - return torch.cat(all_x, dim=0) + return torch.cat(all_masked_tensors, dim=0) class VJEPA2PredictorEmbeddings(nn.Module): @@ -649,13 +654,18 @@ class VJEPA2Predictor(nn.Module): self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True) def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None): + # gather position masks + argsort = argsort.to(position_masks.device) position_masks = torch.gather(position_masks, dim=1, index=argsort) - hidden_states = torch.gather( - hidden_states, - dim=1, - index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)), - ) + + # gather hidden states + argsort = argsort.to(hidden_states.device) + hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)) + hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort) + + # gather head mask if head_mask is not None and head_mask[0] is not None: + argsort = argsort.to(head_mask.device) head_mask = head_mask.permute(1, 0, 2, 3, 4) argsort_4d = ( argsort.unsqueeze(1) @@ -673,15 +683,14 @@ class VJEPA2Predictor(nn.Module): ) head_mask = torch.gather(head_mask, dim=4, index=argsort_5d) head_mask = head_mask.permute(1, 0, 2, 3, 4) + return hidden_states, position_masks, head_mask def unsort_tokens(self, hidden_states, argsort): + argsort = argsort.to(hidden_states.device) reverse_argsort = torch.argsort(argsort, dim=1) - hidden_states = torch.gather( - hidden_states, - dim=1, - index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)), - ) + reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)) + hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort) return 
hidden_states @can_return_tuple @@ -735,49 +744,304 @@ class VJEPA2Predictor(nn.Module): ) +class VJEPA2PoolerSelfAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: VJEPA2Config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class VJEPA2PoolerCrossAttention(nn.Module): + """It's different from other cross-attention layers, doesn't have output projection layer (o_proj)""" + + # in case of modular refactoring - o_proj can be replaces with nn.Identity() + + def __init__(self, config: VJEPA2Config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + queries: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_seq_length, embed_dim = queries.shape + kv_seq_length = keys.shape[1] + + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + queries = queries.view(batch_size, q_seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# Modified from SiglipEncoderLayer, but we have to propagate proper hidden_size to VJEPA2MLP +class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer): + def __init__(self, config: VJEPA2Config): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.self_attn = VJEPA2PoolerSelfAttention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
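+ Returns: + `tuple(torch.FloatTensor)`: The transformed hidden states, plus the attention weights when `output_attentions=True`.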
+ """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer): + def __init__(self, config: VJEPA2Config): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.cross_attn = VJEPA2PoolerCrossAttention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size) + + def forward( + self, + queries: torch.Tensor, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + # Apply cross-attention + residual = queries + hidden_state = self.layer_norm1(hidden_state) + hidden_state, *attn_weights = self.cross_attn( + queries, + hidden_state, + hidden_state, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_state = residual + hidden_state + + # Apply MLP + residual = hidden_state + hidden_state = self.layer_norm2(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state = residual + hidden_state + + outputs = (hidden_state,) + if output_attentions: + outputs += tuple(attn_weights) + + return outputs + + +class VJEPA2AttentivePooler(nn.Module): + """Attentive Pooler""" + + def __init__(self, config: VJEPA2Config): + super().__init__() + self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config) + self.self_attention_layers = nn.ModuleList( + [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)] + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + for layer in self.self_attention_layers: + hidden_state = layer(hidden_state, attention_mask=None)[0] + queries = self.query_tokens.repeat(hidden_state.shape[0], 1, 1) + hidden_state = self.cross_attention_layer(queries, hidden_state)[0] + return hidden_state.squeeze(1) + + @auto_docstring class VJEPA2PreTrainedModel(PreTrainedModel): config_class = VJEPA2Config base_model_prefix = "vjepa2" main_input_name = "pixel_values_videos" supports_gradient_checkpointing = True - _no_split_modules = ["VJEPA2Layer"] + _no_split_modules = [ + "VJEPA2Layer", + "VJEPA2PoolerSelfAttentionLayer", + "VJEPA2PoolerCrossAttentionLayer", + "VJEPA2PredictorEmbeddings", + ] _supports_sdpa = True _supports_flash_attn_2 = True - def _init_weights( - self, - module: Union[ - nn.Linear, - nn.Conv2d, - nn.LayerNorm, - VJEPA2Embeddings, - VJEPA2PredictorEmbeddings, - ], - ): + def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.weight.dtype) + + init_std = self.config.initializer_range + 
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + def trunc_normal_f32_(weight, std): + data_float_32 = weight.data.to(torch.float32) + data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std) + weight.data = data_init.to(weight.dtype) + + if isinstance(module, VJEPA2AttentivePooler): + trunc_normal_f32_(module.query_tokens, std=init_std) + for i, layer in enumerate(module.self_attention_layers, 1): + std = init_std / (i**0.5) + trunc_normal_f32_(layer.self_attn.out_proj.weight, std=std) + trunc_normal_f32_(layer.mlp.fc2.weight, std=std) + std = init_std / (len(module.self_attention_layers) + 1) ** 0.5 + trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=std) + elif isinstance(module, VJEPA2PredictorEmbeddings): + if module.zero_init_mask_tokens: + module.mask_tokens.data.zero_() + else: + trunc_normal_f32_(module.mask_tokens, std=init_std) + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): + trunc_normal_f32_(module.weight, std=init_std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) - elif isinstance(module, VJEPA2PredictorEmbeddings): - if not module.zero_init_mask_tokens: - module.mask_token = nn.init.trunc_normal_( - module.mask_token.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.mask_token.dtype) - else: - module.mask_tokens.data.zero_() def _convert_head_mask_to_5d(head_mask, num_hidden_layers): @@ -900,4 +1164,92 @@ class VJEPA2Model(VJEPA2PreTrainedModel): return encoder_output.last_hidden_state -__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"] +@auto_docstring( + custom_intro=""" + V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler). + """ +) +class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel): + def __init__(self, config: VJEPA2Config): + super().__init__(config) + + self.num_labels = config.num_labels + self.vjepa2 = VJEPA2Model(config) + + # Classifier head + self.pooler = VJEPA2AttentivePooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values_videos: torch.Tensor, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[Tuple, ImageClassifierOutput]: + r""" + pixel_values_videos (`torch.Tensor` with shape `[batch size x num_frames x num_channels x height x width]`): + The input video pixels which is processed by VJEPA2VideoProcessor. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Examples: + + ```python + >>> import torch + >>> import numpy as np + >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification + + >>> device = "cuda" + + >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2") + >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device) + + >>> video = np.ones((64, 256, 256, 3)) # 64 frames, 256x256 RGB + >>> inputs = video_processor(video, return_tensors="pt").to(device) + + >>> # For inference + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> logits = outputs.logits + + >>> predicted_label = logits.argmax(-1).item() + >>> print(model.config.id2label[predicted_label]) + + >>> # For training + >>> labels = torch.ones(1, dtype=torch.long, device=device) + >>> loss = model(**inputs, labels=labels).loss + + ```""" + + outputs = self.vjepa2( + pixel_values_videos=pixel_values_videos, + skip_predictor=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = outputs.last_hidden_state + pooler_output = self.pooler(last_hidden_state) + logits = self.classifier(pooler_output) + + loss = None + if labels is not None: + loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel", "VJEPA2ForVideoClassification"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 884dab4720b..07015927265 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -56,6 +56,7 @@ from ..models.auto.modeling_auto import ( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) @@ -194,6 +195,7 @@ _SPECIAL_SUPPORTED_MODELS = [ "TrOCRDecoder", "PeftModelForCausalLM", "PeftModelForSeq2SeqLM", + "VJEPA2ForVideoClassification", # TODO: add support for them as it should be quite easy to do so (small blocking issues). 
# XLNetForQuestionAnswering, ] @@ -904,6 +906,7 @@ class HFTracer(Tracer): *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES), *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES), *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES), + *get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES), *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES), ]: diff --git a/tests/models/vjepa2/test_modeling_vjepa2.py b/tests/models/vjepa2/test_modeling_vjepa2.py index 8a4b55ad6aa..5a38962771c 100644 --- a/tests/models/vjepa2/test_modeling_vjepa2.py +++ b/tests/models/vjepa2/test_modeling_vjepa2.py @@ -40,7 +40,7 @@ if is_torch_available(): import torch from torch import nn - from transformers import VJEPA2Model + from transformers import VJEPA2ForVideoClassification, VJEPA2Model if is_vision_available(): @@ -153,7 +153,7 @@ class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): test_torch_exportable = True - all_model_classes = (VJEPA2Model,) if is_torch_available() else () + all_model_classes = (VJEPA2Model, VJEPA2ForVideoClassification) if is_torch_available() else () fx_compatible = True @@ -267,7 +267,7 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase): [[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]], device=torch_device, ) - torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=8e-2, atol=8e-2) @slow def test_inference_video(self): @@ -343,3 +343,22 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase): # verify the last hidden states expected_shape = torch.Size((1, num_masks, 1024)) self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape) + + @slow + def test_video_classification(self): + checkpoint = "facebook/vjepa2-vitl-fpc16-256-ssv2" + + model = VJEPA2ForVideoClassification.from_pretrained(checkpoint).to(torch_device) + video_processor = AutoVideoProcessor.from_pretrained(checkpoint) + + sample_video = np.ones((16, 3, 256, 256)) + inputs = video_processor(sample_video, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + self.assertEqual(outputs.logits.shape, (1, 174)) + + expected_logits = torch.tensor([0.8814, -0.1195, -0.6389], device=torch_device) + resulted_logits = outputs.logits[0, 100:103] + torch.testing.assert_close(resulted_logits, expected_logits, rtol=1e-2, atol=1e-2)