Add V-JEPA 2 (#38746)

* adding model and conversion scripts

* add imports to test vjepa conversion

* fix imports and make conversion work

* fix computation for short side

* replace attention with library attention function

* cleanup more attention classes

* remove config overrides

* add test cases, fix some of the failing ones

* fix the model outputs

* fix outputs of the model per review

* fix too big model test case

* fix styling __init__.py

* fix initialization test

* remove all asserts per review

* update sorting unsorting logic as per feedback

* remove is_video per review

* remove another is_video segment

* remove unwanted stuff

* small fixes

* add docstrings for the model

* revert adding vjepa2 config here

* update styling

* add config docstrings (wip)

* fix dpr issue

* removed test failing issues

* update styles

* merge predictor configs into main config

* remove processing code, add video processor

* remove permute which is not necessary now

* fix styles

* updated vjepa2 to be in video_processing_auto

* update comment for preprocessing

* test integration test and fix the outputs

* update test values, change test to look at repeated frames for a given image

* add a simple video processing test

* refactoring pixel_values_videos and upload ckpts to original

* fix torch_fx test cases

* remove unused config

* add all config docstrings

* add more integration tests

* add basic doc

* revert unwanted styling changes

* working make fixup

* Fix model_type in config

* update attention implementation to fit new hf standards

* fix the preprocessing logic, ensure it matches the original model

* remove use_rope logic, cleanup

* fix docstrings

* Further cleanup, update doc

* Fix model prefix

* fix get_vision_features

* VJEPA2Embeddings style refactor

* nit, style comment

* change modules default values

* Only `str` activation in config

* GradientCheckpointingLayer

* fixup

* fix conversion script

* Remove return_dict

* remove None return typehint

* Refactor VJEPA2Layer, remove use_SiLU

* Fix fx tests

* dpr -> drop_path_rates

* move *ModelOutput on top

* format docs bit

* update docs

* update docs

* update doc example

* remove prune_heads from model

* remove unused config params

* refactor embed signature

* Add vjepa to docs

* Fix config docstring

* update defaults

* Update docs/source/en/model_doc/vjepa2.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/vjepa2.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix import

* Min refactoring

* Update HUB_SOURCE and HUB_REPO in conversion script

* Add missing headers

* VJEPA -> V-JEPA in docs

* Add image to doc

* fix style

* fix init weights

* change checkpoint name in modeling tests

---------

Co-authored-by: Koustuv Sinha <koustuv.sinha@mail.mcgill.ca>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: Koustuv Sinha <koustuvsinha@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Pavel Iakubovskii 2025-06-11 15:00:08 +01:00 committed by GitHub
parent a6f0e2b64a
commit 84710a4291
15 changed files with 1919 additions and 0 deletions


@@ -905,6 +905,8 @@
- sections:
  - local: model_doc/timesformer
    title: TimeSformer
  - local: model_doc/vjepa2
    title: V-JEPA 2
  - local: model_doc/videomae
    title: VideoMAE
  - local: model_doc/vivit


@@ -0,0 +1,82 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
</div>
</div>
# V-JEPA 2
V-JEPA 2 is a self-supervised approach to training video encoders developed by FAIR, Meta. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vjepa.gif" alt="drawing" width="600"/>
</div>
You can find all original V-JEPA 2 checkpoints under the [V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) collection.
This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yonigozlan](https://huggingface.co/yonigozlan) and [qubvel](https://huggingface.co/qubvel-hf). The original code can be found [here](https://github.com/facebookresearch/vjepa2).
## Usage example
The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class.
```py
import torch
from torchcodec.decoders import VideoDecoder
import numpy as np
from transformers import AutoModel, AutoVideoProcessor
processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
model = AutoModel.from_pretrained(
"facebook/vjepa2-vitl-fpc64-256",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="sdpa"
)
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64)  # sample some frames; you can define a more complex sampling strategy here
video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W
video = processor(video, return_tensors="pt").to(model.device)
outputs = model(**video)
# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()`
encoder_outputs = outputs.last_hidden_state
# V-JEPA 2 predictor outputs
predictor_outputs = outputs.predictor_output.last_hidden_state
```
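If you only need the encoder representations, for example as features for a downstream task, the predictor forward pass can be skipped. The sketch below reuses `model` and `video` from the snippet above; `skip_predictor` and `get_vision_features` are part of [`VJEPA2Model`].
```py
# Encoder-only forward pass, skipping the predictor
encoder_only = model(**video, skip_predictor=True)
features = encoder_only.last_hidden_state
# Convenience method returning the encoder's last hidden state
features = model.get_vision_features(video["pixel_values_videos"])
```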
## VJEPA2Config
[[autodoc]] VJEPA2Config
## VJEPA2Model
[[autodoc]] VJEPA2Model
- forward
## VJEPA2VideoProcessor
[[autodoc]] VJEPA2VideoProcessor


@@ -323,6 +323,7 @@ if TYPE_CHECKING:
from .vitpose_backbone import *
from .vits import *
from .vivit import *
from .vjepa2 import *
from .wav2vec2 import *
from .wav2vec2_bert import *
from .wav2vec2_conformer import *


@@ -365,6 +365,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("vitpose_backbone", "VitPoseBackboneConfig"),
("vits", "VitsConfig"),
("vivit", "VivitConfig"),
("vjepa2", "VJEPA2Config"),
("wav2vec2", "Wav2Vec2Config"),
("wav2vec2-bert", "Wav2Vec2BertConfig"),
("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
@@ -750,6 +751,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("vitpose_backbone", "ViTPoseBackbone"),
("vits", "VITS"),
("vivit", "ViViT"),
("vjepa2", "VJEPA2Model"),
("wav2vec2", "Wav2Vec2"),
("wav2vec2-bert", "Wav2Vec2-BERT"),
("wav2vec2-conformer", "Wav2Vec2-Conformer"),


@@ -336,6 +336,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("vitdet", "VitDetModel"),
("vits", "VitsModel"),
("vivit", "VivitModel"),
("vjepa2", "VJEPA2Model"),
("wav2vec2", "Wav2Vec2Model"),
("wav2vec2-bert", "Wav2Vec2BertModel"),
("wav2vec2-conformer", "Wav2Vec2ConformerModel"),


@@ -56,6 +56,7 @@ else:
("qwen2_vl", "Qwen2VLVideoProcessor"),
("smolvlm", "SmolVLMVideoProcessor"),
("video_llava", "VideoLlavaVideoProcessor"),
("vjepa2", "VJEPA2VideoProcessor"),
]
)


@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_vjepa2 import *
from .modeling_vjepa2 import *
from .video_processing_vjepa2 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@@ -0,0 +1,146 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VJEPA 2 model configuration"""
from ...configuration_utils import PretrainedConfig
class VJEPA2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VJEPA2Model`]. It is used to instantiate a
VJEPA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VJEPA2
[facebook/vjepa2-vitl-fpc64-256](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
crop_size (`int`, *optional*, defaults to 256):
Input resolution of the model
frames_per_clip (`int`, *optional*, defaults to 64):
The number of frames the model has been pretrained with. Does not impact inference.
tubelet_size (`int`, *optional*, defaults to 2):
The number of temporal frames combined into a single tubelet (patch along the time axis); see the paper for more information.
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers
in_chans (`int`, *optional*, defaults to 3):
The number of input channels
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Encoder
num_hidden_layers (`int`, *optional*, defaults to 24):
The number of hidden layers
drop_path_rate (`float`, *optional*, defaults to 0.0):
Stochastic depth rate per sample (when applied in the main path of residual layers).
mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of the hidden size of the MLPs used in Encoder relative to the `hidden_size`.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention probabilities.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
pred_hidden_size (`int`, *optional*, defaults to 384):
Dimensionality of the predictor layers
pred_num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Predictor
pred_num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Predictor
pred_num_mask_tokens (`int`, *optional*, defaults to 10):
The number of mask tokens to use in the Predictor.
pred_zero_init_mask_tokens (`bool`, *optional*, defaults to `True`):
Whether to initialize the mask tokens in the Predictor with zeros.
pred_mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of the hidden size of the MLPs used in Predictor relative to the `pred_hidden_size`.
Example:
```python
>>> from transformers import VJEPA2Config, VJEPA2Model
>>> # Initializing a VJEPA2 vjepa2-vitl-fpc64-256 style configuration
>>> configuration = VJEPA2Config()
>>> # Initializing a model (with random weights) from the vjepa2-vitl-fpc64-256 style configuration
>>> model = VJEPA2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "vjepa2"
def __init__(
self,
patch_size=16,
crop_size=256,
frames_per_clip=64,
tubelet_size=2,
hidden_size=1024,
in_chans=3,
num_attention_heads=16,
num_hidden_layers=24,
drop_path_rate=0.0,
mlp_ratio=4.0,
layer_norm_eps=1e-6,
qkv_bias=True,
attention_probs_dropout_prob=0.0,
hidden_act="gelu",
initializer_range=0.02,
# predictor params
pred_hidden_size=384,
pred_num_attention_heads=12,
pred_num_hidden_layers=12,
pred_num_mask_tokens=10,
pred_zero_init_mask_tokens=True,
pred_mlp_ratio=4.0,
**kwargs,
):
super().__init__(**kwargs)
self.crop_size = crop_size
self.frames_per_clip = frames_per_clip
self.patch_size = patch_size
self.tubelet_size = tubelet_size
self.hidden_size = hidden_size
self.in_chans = in_chans
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.drop_path_rate = drop_path_rate
self.mlp_ratio = mlp_ratio
self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.image_size = crop_size
# predictor params
self.pred_hidden_size = pred_hidden_size
self.pred_num_attention_heads = pred_num_attention_heads
self.pred_num_hidden_layers = pred_num_hidden_layers
self.pred_num_mask_tokens = pred_num_mask_tokens
self.pred_zero_init_mask_tokens = pred_zero_init_mask_tokens
self.pred_mlp_ratio = pred_mlp_ratio
__all__ = ["VJEPA2Config"]


@@ -0,0 +1,346 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import tempfile
from pathlib import Path
import numpy as np
import requests
import torch
from huggingface_hub import HfApi
from PIL import Image
from transformers import VJEPA2Config, VJEPA2Model, VJEPA2VideoProcessor
from transformers.models.vjepa2.modeling_vjepa2 import apply_masks
HUB_REPO = "https://github.com/facebookresearch/vjepa2"
HUB_SOURCE = "github"
HUB_MODELS = {
"vit_large": "facebook/vjepa2-vitl-fpc64-256",
"vit_huge": "facebook/vjepa2-vith-fpc64-256",
"vit_giant": "facebook/vjepa2-vitg-fpc64-256",
"vit_giant_384": "facebook/vjepa2-vitg-fpc64-384",
}
S3_MODELS = {
"vit_large": "https://dl.fbaipublicfiles.com/vjepa2/vitl.pt",
"vit_huge": "https://dl.fbaipublicfiles.com/vjepa2/vith.pt",
"vit_giant": "https://dl.fbaipublicfiles.com/vjepa2/vitg.pt",
"vit_giant_384": "https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt",
}
TOKEN = os.environ.get("HF_TOKEN", None)
def get_vjepa2_config(model_name):
# size of the architecture
if model_name == "vit_large":
return VJEPA2Config(
crop_size=256,
frames_per_clip=64,
hidden_size=1024,
num_attention_heads=16,
num_hidden_layers=24,
mlp_ratio=4,
pred_hidden_size=384,
pred_num_attention_heads=12,
pred_num_hidden_layers=12,
pred_num_mask_tokens=10,
)
elif model_name == "vit_huge":
return VJEPA2Config(
crop_size=256,
frames_per_clip=64,
hidden_size=1280,
num_attention_heads=16,
num_hidden_layers=32,
mlp_ratio=4,
pred_hidden_size=384,
pred_num_attention_heads=12,
pred_num_hidden_layers=12,
pred_num_mask_tokens=10,
)
elif model_name == "vit_giant":
return VJEPA2Config(
crop_size=256,
frames_per_clip=64,
hidden_size=1408,
num_attention_heads=22,
num_hidden_layers=40,
mlp_ratio=48 / 11,
pred_hidden_size=384,
pred_num_attention_heads=12,
pred_num_hidden_layers=12,
pred_num_mask_tokens=10,
)
elif model_name == "vit_giant_384":
return VJEPA2Config(
crop_size=384,
frames_per_clip=64,
hidden_size=1408,
num_attention_heads=22,
num_hidden_layers=40,
mlp_ratio=48 / 11,
pred_hidden_size=384,
pred_num_attention_heads=12,
pred_num_hidden_layers=12,
pred_num_mask_tokens=10,
)
else:
raise ValueError("Model not supported")
def convert_encoder_keys(model_state_dict, og_encoder_state_dict, config):
emb_dim = config.hidden_size
for key, val in og_encoder_state_dict.copy().items():
val = og_encoder_state_dict.pop(key)
key = key.replace("module.backbone.", "")
if key.startswith("blocks."):
key = key.replace("blocks.", "encoder.layer.")
if "attn." in key:
key = key.replace("attn.", "attention.")
if key == "pos_embed":
key = "encoder.embeddings.position_embeddings"
if "patch_embed." in key:
key = key.replace("patch_embed.", "encoder.embeddings.patch_embeddings.")
if key.startswith("norm."):
key = key.replace("norm.", "encoder.layernorm.")
if "qkv." in key:
prefix, suffix = key.split("qkv")
if "bias" in suffix:
q_e, k_e, v_e = (
val[0:emb_dim],
val[emb_dim : emb_dim * 2],
val[emb_dim * 2 :],
)
else:
q_e, k_e, v_e = (
val[0:emb_dim, :],
val[emb_dim : emb_dim * 2, :],
val[emb_dim * 2 :, :],
)
og_encoder_state_dict[prefix + "query" + suffix] = q_e
og_encoder_state_dict[prefix + "key" + suffix] = k_e
og_encoder_state_dict[prefix + "value" + suffix] = v_e
else:
og_encoder_state_dict[key] = val
return og_encoder_state_dict
def convert_predictor_keys(model_state_dict, og_predictor_state_dict, config):
emb_dim = config.pred_hidden_size
if "predictor_pos_embed" in og_predictor_state_dict:
del og_predictor_state_dict["predictor_pos_embed"]
# update predictor weights
mask_tokens = {}
mask_token_keys_to_delete = []
for key, val in og_predictor_state_dict.copy().items():
val = og_predictor_state_dict.pop(key)
key = key.replace("module.backbone.", "")
if key.startswith("predictor_blocks."):
key = key.replace("predictor_blocks.", "predictor.layer.")
if "attn." in key:
key = key.replace("attn.", "attention.")
if key == "predictor_pos_embed":
key = "predictor.embeddings.position_embeddings"
if "predictor_embed." in key:
key = key.replace("predictor_embed.", "predictor.embeddings.predictor_embeddings.")
if "mask_tokens." in key:
mask_tokens[key.split("mask_tokens.")[-1]] = val
mask_token_keys_to_delete.append(key)
# key = key.replace("mask_tokens.", "predictor.embeddings.mask_tokens.")
if key.startswith("predictor_norm."):
key = key.replace("predictor_norm.", "predictor.layernorm.")
if key.startswith("predictor_proj."):
key = key.replace("predictor_proj.", "predictor.proj.")
if "qkv." in key:
prefix, suffix = key.split("qkv")
if "bias" in suffix:
q_e, k_e, v_e = (
val[0:emb_dim],
val[emb_dim : emb_dim * 2],
val[emb_dim * 2 :],
)
else:
q_e, k_e, v_e = (
val[0:emb_dim, :],
val[emb_dim : emb_dim * 2, :],
val[emb_dim * 2 :, :],
)
og_predictor_state_dict[prefix + "query" + suffix] = q_e
og_predictor_state_dict[prefix + "key" + suffix] = k_e
og_predictor_state_dict[prefix + "value" + suffix] = v_e
else:
og_predictor_state_dict[key] = val
mask_tokens = torch.stack([mask_tokens[f"{i}"] for i in range(len(mask_tokens))], dim=0)
for k in mask_token_keys_to_delete:
del og_predictor_state_dict[k]
og_predictor_state_dict["predictor.embeddings.mask_tokens"] = mask_tokens
return og_predictor_state_dict
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
return image
def upload_original_ckpts(model_name):
hf_repo = HUB_MODELS[model_name]
original_ckpt = S3_MODELS[model_name]
print(f"Uploading original checkpoint for vjepa2 {model_name} to {hf_repo}/original/")
with tempfile.NamedTemporaryFile() as fn:
local_path = fn.name
torch.hub.download_url_to_file(original_ckpt, local_path)
api = HfApi()
api.upload_file(
repo_id=hf_repo,
path_or_fileobj=local_path,
path_in_repo="original/model.pth",
repo_type="model",
token=TOKEN,
)
print("Uploading complete")
@torch.no_grad()
def convert_and_test_vjepa2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our VJEPA2 structure.
"""
config = get_vjepa2_config(model_name)
# load original model from torch hub
original_encoder, original_predictor = torch.hub.load(HUB_REPO, "vjepa2_" + model_name, source=HUB_SOURCE)
original_encoder.eval()
original_predictor.eval()
original_preprocessor = torch.hub.load(
HUB_REPO, "vjepa2_preprocessor", source=HUB_SOURCE, crop_size=config.crop_size
)
# load state_dict of original model, remove and rename some keys
encoder_state_dict = original_encoder.state_dict()
decoder_state_dict = original_predictor.state_dict()
model = VJEPA2Model(config).eval()
state_dict = model.state_dict()
og_encoder_sd = convert_encoder_keys(state_dict, encoder_state_dict, config)
og_predictor_sd = convert_predictor_keys(state_dict, decoder_state_dict, config)
og_state_dict = og_encoder_sd
og_state_dict.update(og_predictor_sd)
model.load_state_dict(og_state_dict)
# load image
image = prepare_img()
image = torch.Tensor(np.array(image)).unsqueeze(0).permute(0, 3, 1, 2)
print("Input shape: ", image.shape)
crop_size = config.crop_size
processor = VJEPA2VideoProcessor(crop_size=crop_size)
pr_out = processor(image, return_tensors="pt")
pixel_values_videos = pr_out.pixel_values_videos
# run original preprocessor
original_pixel_values = original_preprocessor(image)
assert original_pixel_values[0].permute(1, 0, 2, 3).shape == pixel_values_videos[0].shape
assert torch.allclose(original_pixel_values[0].permute(1, 0, 2, 3), pixel_values_videos[0], atol=1e-3)
with torch.no_grad():
# reshape and move to gpu
if pixel_values_videos.size(1) == 1:
pixel_values_videos = pixel_values_videos.repeat(1, config.frames_per_clip, 1, 1, 1)
# pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4) # B x C x T x H x W
pixel_values_videos = pixel_values_videos.to(device="cuda", dtype=torch.float32)
original_encoder = original_encoder.to(device="cuda", dtype=torch.float32)
original_predictor = original_predictor.to(device="cuda", dtype=torch.float32)
model = model.to(device="cuda", dtype=torch.float32)
# forward
original_encoder_outputs = original_encoder(pixel_values_videos.permute(0, 2, 1, 3, 4))
B, N, _ = original_encoder_outputs.shape
# test full mask
context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
predictor_mask = context_mask
original_predictor_outputs = original_predictor(original_encoder_outputs, context_mask, predictor_mask)
outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3)
predictor_outputs = outputs.predictor_output
assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3)
# test partial mask
window_size = 256
mask = torch.arange(N, device=pixel_values_videos.device).unsqueeze(0)
context_mask = [mask[:, :window_size].repeat((B, 1))]
predictor_mask = [mask[:, window_size : window_size * 2].repeat((B, 1))]
original_predictor_outputs = original_predictor(
apply_masks(original_encoder_outputs, context_mask),
context_mask,
predictor_mask,
)
outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3)
predictor_outputs = outputs.predictor_output
assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
name = HUB_MODELS[model_name]
model.push_to_hub(name, private=True)
processor.push_to_hub(name, private=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="vit_large",
type=str,
choices=[
"vit_large",
"vit_huge",
"vit_giant",
"vit_giant_384",
],
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model to the 🤗 hub.",
)
parser.add_argument("--upload_original", action="store_true", help="upload the original checkpoint")
args = parser.parse_args()
convert_and_test_vjepa2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
if args.upload_original:
upload_original_ckpts(args.model_name)
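# Example invocation (illustrative; assuming this script is saved as `convert_vjepa2_to_hf.py`):
#   python convert_vjepa2_to_hf.py --model_name vit_large --pytorch_dump_folder_path ./vjepa2-vitl-fpc64-256
#   python convert_vjepa2_to_hf.py --model_name vit_giant_384 --push_to_hub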


@@ -0,0 +1,903 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_vjepa2 import VJEPA2Config
logger = logging.get_logger(__name__)
@dataclass
class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
"""
VJEPA Predictor outputs that also contain the masked encoder outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
masked_hidden_state (`torch.FloatTensor`, *optional*):
Returned when `context_mask` is provided, which is applied on the VJEPA2Encoder outputs.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
target_hidden_state (`torch.FloatTensor`, *optional*):
Returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs.
"""
last_hidden_state: torch.FloatTensor
masked_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
target_hidden_state: Optional[torch.FloatTensor] = None
@dataclass
class VJEPA2WithMaskedInputModelOutput(ModelOutput):
"""
VJEPA outputs that also contain the masked encoder outputs.
Optionally contains the predictor outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
masked_hidden_state (`torch.FloatTensor`, *optional*):
Returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
Returns the output from the Predictor module
"""
last_hidden_state: torch.FloatTensor
masked_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
predictor_output: Optional[VJEPA2WithMaskedInputPredictorOutput] = None
def to_tuple(self):
output = list(super().to_tuple())
if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput):
output[-1] = output[-1].to_tuple()
return tuple(output)
class VJEPA2PatchEmbeddings3D(nn.Module):
"""
Video (3D) to Patch Embedding
"""
def __init__(
self,
config: VJEPA2Config,
hidden_size: int = 1024,
):
super().__init__()
self.patch_size = config.patch_size
self.tubelet_size = config.tubelet_size
self.hidden_size = hidden_size
self.proj = nn.Conv3d(
in_channels=config.in_chans,
out_channels=hidden_size,
kernel_size=(config.tubelet_size, config.patch_size, config.patch_size),
stride=(config.tubelet_size, config.patch_size, config.patch_size),
)
@staticmethod
def num_patches(config):
return (
(config.frames_per_clip // config.tubelet_size)
* (config.crop_size // config.patch_size)
* (config.crop_size // config.patch_size)
)
def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2)
return x
class VJEPA2Embeddings(nn.Module):
"""
Construct mask token, position and patch embeddings.
"""
def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
super().__init__()
self.config = config
self.hidden_size = hidden_size
self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)
self.num_patches = self.patch_embeddings.num_patches
self.patch_size = config.patch_size
def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
num_frames = pixel_values_videos.shape[1]
# Swap `frames` and `channels` dims, the result is:
# (batch_size, channels, num_frames, height, width)
pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)
# For some cases, if the input vision (image/video) consists of num_frames < tubelet_size,
# then embedding lookup fails. In these cases, we duplicate the frames.
if num_frames < self.config.tubelet_size:
pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1)
target_dtype = self.patch_embeddings.proj.weight.dtype
pixel_values_videos = pixel_values_videos.to(dtype=target_dtype)
embeddings = self.patch_embeddings(pixel_values_videos)
return embeddings
# Adapted from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
# Normalize the attention scores to probabilities.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
# Mask heads if we want to
if attention_mask is not None:
attn_weights = attn_weights * attention_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def rotate_queries_or_keys(x, pos):
B, num_heads, N, D = x.size()
# similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
# the original implementation recomputes this on every call; the usual HF style is to compute `inv_freq` once and store it
# -- compute angle for each position
omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
omega /= D / 2.0
omega = 1.0 / 10000**omega # (D/2,)
freq = torch.einsum("..., f -> ... f", pos, omega) # (..., N, D/2), outer product
# -- build rotation matrix and apply
emb_sin = freq.sin() # (..., N, D/2)
emb_cos = freq.cos() # (..., N, D/2)
emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)
# --
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1)
y = y.flatten(-2)
return (x * emb_cos) + (y * emb_sin)
class VJEPA2RopeAttention(nn.Module):
def __init__(
self,
config: VJEPA2Config,
hidden_size: int = 1024,
num_attention_heads: int = 16,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
if hidden_size % num_attention_heads != 0:
raise ValueError(
f"The hidden size {(hidden_size,)} is not a multiple of the number of attention "
f"heads {num_attention_heads}."
)
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
self.proj = nn.Linear(hidden_size, hidden_size)
self.dropout_prob = config.attention_probs_dropout_prob
self.dropout = nn.Dropout(self.dropout_prob)
self.grid_size = self.config.crop_size // self.config.patch_size
self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size
self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))
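# Note: the attention head dimension is split into three equal, even-sized chunks (d_dim, h_dim, w_dim)
# that receive rotary embeddings along the temporal, height and width axes respectively; any leftover
# dimensions are concatenated back unrotated in `apply_rotary_embeddings` below.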
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def _get_frame_pos(self, ids):
tokens_per_frame = int(self.grid_size * self.grid_size)
return ids // tokens_per_frame
def _get_height_pos(self, ids):
# Remove frame component from ids
tokens_per_frame = int(self.grid_size * self.grid_size)
frame_ids = self._get_frame_pos(ids)
ids = ids - tokens_per_frame * frame_ids
# --
tokens_per_row = self.grid_size
return ids // tokens_per_row
def get_position_ids(self, x, masks=None):
device = x.device
token_size = x.size(1)
# Note: when masks is none, we use a 1d id instead of Bxnum_attention_heads mask,
# as 1d vector is broadcasted to the correct shapes.
if masks is not None:
ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
else:
ids = torch.arange(token_size, device=device)
# change to allow for extrapolation
tokens_per_frame = int(self.grid_size * self.grid_size)
frame_ids = self._get_frame_pos(ids)
# --
tokens_per_row = self.grid_size
height_ids = self._get_height_pos(ids)
# --
# Remove frame component from ids (1st term) and height component (2nd term)
width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
return frame_ids, height_ids, width_ids
def apply_rotary_embeddings(self, qk, pos_ids):
d_mask, h_mask, w_mask = pos_ids
s = 0
qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
s += self.d_dim
qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
s += self.h_dim
qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
s += self.w_dim
# Combine rotated dimension
if s < self.attention_head_size:
qkr = qk[..., s:]
qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
else:
qk = torch.cat([qkd, qkh, qkw], dim=-1)
return qk
def forward(
self,
hidden_states,
position_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
head_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = self.proj(context_layer.reshape(new_context_layer_shape))
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Adapted from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output
# Adapted from transformers.models.beit.modeling_beit.BeitDropPath
class VJEPA2DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None):
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class VJEPA2MLP(nn.Module):
def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
super().__init__()
in_features = out_features = hidden_size
hidden_features = int(hidden_size * mlp_ratio)
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
self.activation = ACT2FN[config.hidden_act]
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.fc1(hidden_state)
hidden_state = self.activation(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
class VJEPA2Layer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the original implementation."""
def __init__(
self,
config: VJEPA2Config,
drop_path_rate: float = 0.0,
hidden_size: int = 1024,
num_attention_heads: int = 16,
mlp_ratio: float = 4.0,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.mlp_ratio = mlp_ratio
self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)
def forward(
self,
hidden_states: torch.Tensor,
position_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
# Self-Attention
residual = hidden_states
hidden_states = self.norm1(hidden_states)
self_attention_outputs = self.attention(
hidden_states,
position_mask=position_mask, # position mask for context/target selection
head_mask=head_mask, # head mask is applied at F.scaled_dot_product_attention
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
hidden_states = self.drop_path(attention_output) + residual
# MLP
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
# Add self attentions if we output attention weights
outputs = self_attention_outputs[1:]
outputs = (hidden_states,) + outputs
return outputs
class VJEPA2Encoder(nn.Module):
def __init__(self, config: VJEPA2Config):
super().__init__()
self.config = config
self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
drop_path_rates = [
(config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0)
for i in range(config.num_hidden_layers)
]
self.layer = nn.ModuleList(
[
VJEPA2Layer(
config,
drop_path_rate=drop_path_rates[i],
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
mlp_ratio=config.mlp_ratio,
)
for i in range(config.num_hidden_layers)
]
)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
pixel_values_videos: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> BaseModelOutput:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
hidden_states = self.embeddings(pixel_values_videos)
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(hidden_states, None, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
hidden_states = self.layernorm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def apply_masks(x, masks) -> torch.Tensor:
"""
:param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
:param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
"""
all_x = []
for m in masks:
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
all_x += [torch.gather(x, dim=1, index=mask_keep)]
return torch.cat(all_x, dim=0)
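# Illustrative example: for x of shape [B, N, D] and masks = [idx_1, idx_2], where each idx_i has shape
# [B, K] and holds the indices of patches to keep, apply_masks gathers the K selected patches per sample
# for each mask and concatenates the results along the batch dimension, returning a [2 * B, K, D] tensor.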
class VJEPA2PredictorEmbeddings(nn.Module):
"""
Construct mask token, position and patch embeddings.
"""
def __init__(self, config: VJEPA2Config):
super().__init__()
self.config = config
self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
self.num_mask_tokens = 0
self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
self.num_mask_tokens = config.pred_num_mask_tokens
self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))
self.patch_size = config.patch_size
self.config = config
@staticmethod
def num_patches(config):
if config.frames_per_clip > 1:
return (
(config.frames_per_clip // config.tubelet_size)
* (config.crop_size // config.patch_size)
* (config.crop_size // config.patch_size)
)
else:
return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)
def forward(
self,
hidden_states: torch.Tensor,
context_mask: List[torch.Tensor],
target_mask: List[torch.Tensor],
mask_index: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
hidden_states : encoder outputs (context)
context_mask: tokens of the context (outputs from the encoder)
target_mask: tokens to predict
mask_index: index of the target mask to choose (useful for multiclip?)
"""
B = hidden_states.size(0)
context = self.predictor_embeddings(hidden_states)
# Make target tokens
mask_index = mask_index % self.num_mask_tokens
target = self.mask_tokens[mask_index]
# Note: this is problematic if the config isn't initialized with the right frames_per_clip value,
# e.g. for scenarios if we want to run predictor for more tokens than in the config.
# target = target.repeat(B, self.num_patches(self.config), 1)
# Remedy: use the provided target mask to get the max patch num
max_patch_num = target_mask[0].max() + 1 # one extra to include the last patch
target = target.repeat(B, max_patch_num, 1)
target = apply_masks(target, target_mask)
# Concatenate context & target tokens
context = context.repeat(len(context_mask), 1, 1)
embeddings = torch.cat([context, target], dim=1)
# Positions of context & target tokens
cm = torch.cat(context_mask, dim=0)
tm = torch.cat(target_mask, dim=0)
masks = torch.cat([cm, tm], dim=1)
return embeddings, masks
class VJEPA2Predictor(nn.Module):
def __init__(self, config: VJEPA2Config):
super().__init__()
self.config = config
self.gradient_checkpointing = False
self.embeddings = VJEPA2PredictorEmbeddings(config)
drop_path_rates = [
(
config.drop_path_rate * i / (config.pred_num_hidden_layers - 1)
if config.pred_num_hidden_layers > 1
else 0.0
)
for i in range(config.pred_num_hidden_layers)
]
self.layer = nn.ModuleList(
[
VJEPA2Layer(
config,
drop_path_rate=drop_path_rates[i],
hidden_size=config.pred_hidden_size,
num_attention_heads=config.pred_num_attention_heads,
mlp_ratio=config.pred_mlp_ratio,
)
for i in range(config.pred_num_hidden_layers)
]
)
self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)
def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
position_masks = torch.gather(position_masks, dim=1, index=argsort)
hidden_states = torch.gather(
hidden_states,
dim=1,
index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
)
if head_mask is not None and head_mask[0] is not None:
head_mask = head_mask.permute(1, 0, 2, 3, 4)
argsort_4d = (
argsort.unsqueeze(1)
.unsqueeze(1)
.expand(-1, head_mask.size(1), head_mask.size(2), -1)
.unsqueeze(-1)
.expand(-1, -1, -1, -1, head_mask.size(-1))
)
head_mask = torch.gather(head_mask, dim=3, index=argsort_4d)
argsort_5d = (
argsort.unsqueeze(1)
.unsqueeze(1)
.unsqueeze(1)
.expand(-1, head_mask.size(1), head_mask.size(2), head_mask.size(3), -1)
)
head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
head_mask = head_mask.permute(1, 0, 2, 3, 4)
return hidden_states, position_masks, head_mask
def unsort_tokens(self, hidden_states, argsort):
reverse_argsort = torch.argsort(argsort, dim=1)
hidden_states = torch.gather(
hidden_states,
dim=1,
index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
)
return hidden_states
@can_return_tuple
def forward(
self,
encoder_hidden_states: torch.Tensor,
context_mask: List[torch.Tensor],
target_mask: List[torch.Tensor],
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> BaseModelOutput:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
# mask out the encoder hidden states
# this is implemented here because, in V-JEPA training, a separate encoder is used for the target
encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask)
_, N_ctxt, D = encoder_hidden_states.shape
hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask)
# Put tokens in sorted order
argsort = torch.argsort(position_masks, dim=1) # [B, N]
hidden_states, position_masks, head_mask = self.sort_tokens(hidden_states, position_masks, argsort, head_mask)
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(hidden_states, position_masks, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.layernorm(hidden_states)
# unsort and extract the predicted tokens
hidden_states = self.unsort_tokens(hidden_states, argsort)
hidden_states = hidden_states[:, N_ctxt:]
# projection
hidden_states = self.proj(hidden_states)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
config_class = VJEPA2Config
base_model_prefix = "vjepa2"
main_input_name = "pixel_values_videos"
supports_gradient_checkpointing = True
_no_split_modules = ["VJEPA2Layer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(
self,
module: Union[
nn.Linear,
nn.Conv2d,
nn.LayerNorm,
VJEPA2Embeddings,
VJEPA2PredictorEmbeddings,
],
):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, VJEPA2PredictorEmbeddings):
if not module.zero_init_mask_tokens:
module.mask_tokens.data = nn.init.trunc_normal_(
module.mask_tokens.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.mask_tokens.dtype)
else:
module.mask_tokens.data.zero_()
def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
"""
Inputs:
- head_mask: bsz x seq_length x seq_length | None
Returns
- [num_hidden_layers x batch x num_heads x seq_length x seq_length] | [num_hidden_layers]
"""
if head_mask is not None:
head_mask = head_mask.unsqueeze(1).unsqueeze(0)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
else:
head_mask = [None] * num_hidden_layers
return head_mask
@auto_docstring
class VJEPA2Model(VJEPA2PreTrainedModel):
def __init__(self, config: VJEPA2Config):
super().__init__(config)
self.config = config
self.encoder = VJEPA2Encoder(config)
self.predictor = VJEPA2Predictor(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
return self.encoder.embeddings.patch_embeddings
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values_videos: torch.Tensor,
context_head_mask: Optional[torch.Tensor] = None,
context_mask: Optional[List[torch.Tensor]] = None,
target_head_mask: Optional[torch.Tensor] = None,
target_mask: Optional[List[torch.Tensor]] = None,
skip_predictor: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> VJEPA2WithMaskedInputModelOutput:
r"""
pixel_values_videos (`torch.Tensor` with shape `[batch size x num_frames x num_channels x height x width]`, required):
The input video pixels, processed by [`VJEPA2VideoProcessor`].
context_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the context.
target_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the target.
context_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
The mask position ids indicating which encoder output patches are going to be exposed to the predictor.
By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating full context
available to the predictor.
target_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
The mask position ids indicating which encoder output patches are going to be used as a prediction target
for the predictor. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating
that the predictor should predict all encoder patches.
skip_predictor (bool):
Flag to skip the predictor forward pass; useful if you only need the encoder outputs.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if pixel_values_videos is None:
raise ValueError("You have to specify pixel_values_videos")
# Prepare head mask if needed
context_head_mask = _convert_head_mask_to_5d(context_head_mask, self.config.num_hidden_layers)
target_head_mask = _convert_head_mask_to_5d(target_head_mask, self.config.pred_num_hidden_layers)
encoder_outputs: BaseModelOutput = self.encoder(
pixel_values_videos=pixel_values_videos,
head_mask=context_head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs.last_hidden_state
if context_mask is None and target_mask is None:
B = pixel_values_videos.size(0)
N = sequence_output.size(1)  # infer the number of patches dynamically from the encoder output
context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
target_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
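# Default masks: every encoder patch is visible to the predictor as context,
# and every encoder patch is also a prediction target.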
if not skip_predictor:
predictor_outputs: BaseModelOutput = self.predictor(
encoder_hidden_states=sequence_output,
context_mask=context_mask,
target_mask=target_mask,
head_mask=target_head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
predictor_output = VJEPA2WithMaskedInputPredictorOutput(
last_hidden_state=predictor_outputs.last_hidden_state,
target_hidden_state=apply_masks(sequence_output, target_mask),
hidden_states=predictor_outputs.hidden_states,
attentions=predictor_outputs.attentions,
)
else:
predictor_output = None
encoder_output = VJEPA2WithMaskedInputModelOutput(
last_hidden_state=sequence_output,
masked_hidden_state=apply_masks(sequence_output, context_mask),
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
predictor_output=predictor_output,
)
return encoder_output
def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
encoder_output = self.forward(pixel_values_videos)
return encoder_output.last_hidden_state
__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"]

View File

@ -0,0 +1,59 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Video processor class for VJEPA2."""
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
)
from ...processing_utils import Unpack, VideosKwargs
from ...utils import is_vision_available
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor
if is_vision_available():
from ...image_utils import PILImageResampling
class VJEPA2VideoProcessorInitKwargs(VideosKwargs): ...
@requires(backends=("torchvision",))
class VJEPA2VideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
size = {"shortest_edge": int(256 * 256 / 224)}
crop_size = 256
do_resize = True
do_rescale = True
do_center_crop = True
do_normalize = True
valid_kwargs = VJEPA2VideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[VJEPA2VideoProcessorInitKwargs]):
crop_size = kwargs.get("crop_size", 256)
if not isinstance(crop_size, int):
if not isinstance(crop_size, dict) or "height" not in crop_size:
raise ValueError("crop_size must be an integer or a dictionary with a 'height' key")
crop_size = crop_size["height"]
resize_size = int(crop_size * 256 / 224)
kwargs["size"] = {"shortest_edge": resize_size}
super().__init__(**kwargs)
__all__ = ["VJEPA2VideoProcessor"]

View File

@ -166,6 +166,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"t5", "t5",
"trocr", "trocr",
"vit", "vit",
"vjepa2",
"xglm", "xglm",
"wav2vec2", "wav2vec2",
# "xlnet", # "xlnet",

View File

View File

@ -0,0 +1,345 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch V-JEPA2 model."""
import unittest
import numpy as np
from transformers import VJEPA2Config
from transformers.testing_utils import (
is_flaky,
require_torch,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
from ...test_video_processing_common import (
prepare_video_inputs,
)
if is_torch_available():
import torch
from torch import nn
from transformers import VJEPA2Model
if is_vision_available():
from PIL import Image
from transformers import AutoVideoProcessor
VJEPA_HF_MODEL = "facebook/vjepa2-vitl-fpc64-256"
class VJEPA2ModelTester:
def __init__(
self,
parent,
batch_size=2,
image_size=16,
patch_size=16,
num_channels=3,
hidden_size=32,
num_hidden_layers=4,
num_attention_heads=2,
num_frames=2,
mlp_ratio=1,
pred_hidden_size=32,
pred_num_attention_heads=2,
pred_num_hidden_layers=2,
pred_num_mask_tokens=10,
is_training=False,
attn_implementation="sdpa",
mask_ratio=0.5,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_frames = num_frames
self.mlp_ratio = mlp_ratio
self.pred_hidden_size = pred_hidden_size
self.pred_num_attention_heads = pred_num_attention_heads
self.pred_num_hidden_layers = pred_num_hidden_layers
self.pred_num_mask_tokens = pred_num_mask_tokens
self.attn_implementation = attn_implementation
self.is_training = is_training
self.mask_ratio = mask_ratio
num_patches = ((image_size // patch_size) ** 2) * (num_frames // 2)
self.seq_length = num_patches
self.num_masks = int(self.mask_ratio * self.seq_length)
self.mask_length = num_patches
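# The sequence length equals the number of patch tokens: (image_size // patch_size) ** 2
# spatial patches, with the temporal dimension halved (num_frames // 2). For the slow
# integration tests below, a 16-frame 256x256 clip gives (256 // 16) ** 2 * (16 // 2) = 2048
# tokens, and a single image repeated to 64 frames gives 256 * 32 = 8192 tokens.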
def prepare_config_and_inputs(self):
pixel_values_videos = floats_tensor(
[
self.batch_size,
self.num_frames,
self.num_channels,
self.image_size,
self.image_size,
]
)
config = self.get_config()
return config, pixel_values_videos
def get_config(self):
return VJEPA2Config(
crop_size=self.image_size,
frames_per_clip=self.num_frames,
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
num_hidden_layers=self.num_hidden_layers,
mlp_ratio=self.mlp_ratio,
pred_hidden_size=self.pred_hidden_size,
pred_num_attention_heads=self.pred_num_attention_heads,
pred_num_hidden_layers=self.pred_num_hidden_layers,
pred_num_mask_tokens=self.pred_num_mask_tokens,
)
def create_and_check_model(self, config, pixel_values_videos):
model = VJEPA2Model(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values_videos)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.seq_length, self.hidden_size),
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
pixel_values_videos,
) = config_and_inputs
inputs_dict = {"pixel_values_videos": pixel_values_videos}
return config, inputs_dict
@require_torch
class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as VJEPA2 does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
test_torch_exportable = True
all_model_classes = (VJEPA2Model,) if is_torch_available() else ()
fx_compatible = True
pipeline_model_mapping = {}
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = VJEPA2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=VJEPA2Config, has_text_modality=False, hidden_size=37)
@is_flaky(max_attempts=3, description="`torch.nn.init.trunc_normal_` is flaky.")
def test_initialization(self):
super().test_initialization()
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="VJEPA2 does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="VJEPA2 does not support feedforward chunking yet")
def test_feed_forward_chunking(self):
pass
@slow
def test_model_from_pretrained(self):
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
def prepare_random_video(image_size=256):
videos = prepare_video_inputs(
batch_size=1,
num_frames=16,
num_channels=3,
min_resolution=image_size,
max_resolution=image_size,
equal_resolution=True,
return_tensors="torch",
)
return videos
@require_torch
@require_vision
class VJEPA2ModelIntegrationTest(unittest.TestCase):
@cached_property
def default_video_processor(self):
return AutoVideoProcessor.from_pretrained(VJEPA_HF_MODEL) if is_vision_available() else None
@slow
def test_inference_image(self):
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
video_processor = self.default_video_processor
image = prepare_img()
inputs = video_processor(torch.Tensor(np.array(image)), return_tensors="pt").to(torch_device)
pixel_values_videos = inputs.pixel_values_videos
pixel_values_videos = pixel_values_videos.repeat(1, model.config.frames_per_clip, 1, 1, 1)
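# The processor returns a single-frame clip for a still image; repeating it along the frame
# axis yields a static clip of frames_per_clip (64) frames, hence 256 * 32 = 8192 tokens below.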
# forward pass
with torch.no_grad():
outputs = model(pixel_values_videos)
# verify the last hidden states
expected_shape = torch.Size((1, 8192, 1024))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
device=torch_device,
)
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
@slow
def test_inference_video(self):
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
video_processor = self.default_video_processor
video = prepare_random_video()
inputs = video_processor(video, return_tensors="pt").to(torch_device)
pixel_values_videos = inputs.pixel_values_videos
# forward pass
with torch.no_grad():
outputs = model(pixel_values_videos)
# verify the last hidden states
expected_shape = torch.Size((1, 2048, 1024))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
@slow
def test_predictor_outputs(self):
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
video_processor = self.default_video_processor
video = prepare_random_video()
inputs = video_processor(video, return_tensors="pt").to(torch_device)
pixel_values_videos = inputs.pixel_values_videos
# forward pass
with torch.no_grad():
outputs = model(pixel_values_videos)
# verify the last hidden states
expected_shape = torch.Size((1, 2048, 1024))
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
@slow
def test_predictor_full_mask(self):
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
video_processor = self.default_video_processor
video = prepare_random_video()
inputs = video_processor(video, return_tensors="pt").to(torch_device)
pixel_values_videos = inputs.pixel_values_videos
# forward pass
with torch.no_grad():
context_mask = [torch.arange(2048, device=pixel_values_videos.device).unsqueeze(0)]
predictor_mask = context_mask
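# Identical full-coverage masks: all 2048 patch positions are given to the predictor as
# context, and all of them are requested as prediction targets.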
outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
# verify the last hidden states
expected_shape = torch.Size((1, 2048, 1024))
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
@slow
def test_predictor_partial_mask(self):
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
video_processor = self.default_video_processor
video = prepare_random_video()
inputs = video_processor(video, return_tensors="pt").to(torch_device)
pixel_values_videos = inputs.pixel_values_videos
num_patches = 2048
num_masks = 100
# forward pass
with torch.no_grad():
pos_ids = torch.arange(num_patches, device=pixel_values_videos.device)
context_mask = [pos_ids[0 : num_patches - num_masks].unsqueeze(0)]
predictor_mask = [pos_ids[num_patches - num_masks :].unsqueeze(0)]
outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
# verify the last hidden states
expected_shape = torch.Size((1, num_masks, 1024))
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)

View File

@ -1301,6 +1301,7 @@ class ModelTesterMixin:
"input_values", "input_values",
"inputs_embeds", "inputs_embeds",
"pixel_values", "pixel_values",
"pixel_values_videos",
"token_type_ids", "token_type_ids",
"visual_feats", "visual_feats",
"visual_pos", "visual_pos",