Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Add V-JEPA 2 (#38746)
* adding model and conversion scripts
* add imports to test vjepa conversion
* fix imports and make conversion work
* fix computation for short side
* replace attention with library attention function
* cleanup more attention classes
* remove config overrides
* add test cases, fix some of the failing ones
* fix the model outputs
* fix outputs of the model per review
* fix too big model test case
* fix styling __init__.py
* fix initialization test
* remove all asserts per review
* update sorting unsorting logic as per feedback
* remove is_video per review
* remove another is_video segment
* remove unwanted stuff
* small fixes
* add docstrings for the model
* revert adding vjepa2 config here
* update styling
* add config docstrings (wip)
* fix dpr issue
* removed test failing issues
* update styles
* merge predictor configs into main config
* remove processing code, add video processor
* remove permute which is not necessary now
* fix styles
* updated vjepa2 to be in video_processing_auto
* update comment for preprocessing
* test integration test and fix the outputs
* update test values, change test to look at repeated frames for a given image
* add a simple video processing test
* refactoring pixel_values_videos and upload ckpts to original
* fix torch_fx test cases
* remove unused config
* add all config docstrings
* add more integration tests
* add basic doc
* revert unwanted styling changes
* working make fixup
* Fix model_type in config
* update attention implementation to fit new hf standards
* fix the preprocessing logic, ensure it matches the original model
* remove use_rope logic, cleanup
* fix docstrings
* Further cleanup, update doc
* Fix model prefix
* fix get_vision_features
* VJEPA2Embeddings style refactor
* nit, style comment
* change modules default values
* Only `str` activation in config
* GradientCheckpointingLayer
* fixup
* fix conversion script
* Remove return_dict
* remove None return typehint
* Refactor VJEPA2Layer, remove use_SiLU
* Fix fx tests
* dpr -> drop_path_rates
* move *ModelOutput on top
* format docs bit
* update docs
* update docs
* update doc example
* remove prune_heads from model
* remove unused config params
* refactor embed signature
* Add vjepa to docs
* Fix config docstring
* update defaults
* Update docs/source/en/model_doc/vjepa2.md
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/vjepa2.md
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Fix import
* Min refactoring
* Update HUB_SOURCE and HUB_REPO in conversion script
* Add missing headers
* VJEPA -> V-JEPA in docs
* Add image to doc
* fix style
* fix init weights
* change checkpoint name in modeling tests

---------

Co-authored-by: Koustuv Sinha <koustuv.sinha@mail.mcgill.ca>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: Koustuv Sinha <koustuvsinha@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in: parent a6f0e2b64a, commit 84710a4291
@@ -905,6 +905,8 @@
   - sections:
     - local: model_doc/timesformer
       title: TimeSformer
+    - local: model_doc/vjepa2
+      title: V-JEPA 2
     - local: model_doc/videomae
       title: VideoMAE
     - local: model_doc/vivit
docs/source/en/model_doc/vjepa2.md (new file, 82 lines)
@@ -0,0 +1,82 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
    </div>
</div>

# V-JEPA 2

V-JEPA 2 is a self-supervised approach to training video encoders developed by FAIR, Meta. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vjepa.gif" alt="drawing" width="600"/>
</div>

You can find all original V-JEPA 2 checkpoints under the [V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) collection.

This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yonigozlan](https://huggingface.co/yonigozlan) and [qubvel](https://huggingface.co/qubvel-hf). The original code can be found [here](https://github.com/facebookresearch/vjepa2).

## Usage example

The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class.

```py
import torch
import numpy as np
from torchcodec.decoders import VideoDecoder

from transformers import AutoModel, AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
model = AutoModel.from_pretrained(
    "facebook/vjepa2-vitl-fpc64-256",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
)

video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"

vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64)  # choosing some frames; you can define a more complex sampling strategy here
video = vr.get_frames_at(indices=frame_idx).data  # T x C x H x W
video = processor(video, return_tensors="pt").to(model.device)
outputs = model(**video)

# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()`
encoder_outputs = outputs.last_hidden_state

# V-JEPA 2 predictor outputs
predictor_outputs = outputs.predictor_output.last_hidden_state
```
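As a follow-up sketch (not part of the original snippet), the patch-level encoder features can be pooled into a single clip-level embedding; mean pooling is just one reasonable choice here, not something the model card prescribes.

```py
# `encoder_outputs` has shape (batch_size, num_patches, hidden_size), e.g. (1, 8192, 1024)
# for 64 frames at 256x256 with patch_size=16 and tubelet_size=2.
clip_embedding = encoder_outputs.mean(dim=1)  # (batch_size, hidden_size)
clip_embedding = torch.nn.functional.normalize(clip_embedding, dim=-1)  # unit norm, handy for retrieval
print(clip_embedding.shape)
```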

## VJEPA2Config

[[autodoc]] VJEPA2Config

## VJEPA2Model

[[autodoc]] VJEPA2Model
    - forward

## VJEPA2VideoProcessor

[[autodoc]] VJEPA2VideoProcessor
@@ -323,6 +323,7 @@ if TYPE_CHECKING:
     from .vitpose_backbone import *
     from .vits import *
     from .vivit import *
+    from .vjepa2 import *
     from .wav2vec2 import *
     from .wav2vec2_bert import *
     from .wav2vec2_conformer import *
@@ -365,6 +365,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
         ("vitpose_backbone", "VitPoseBackboneConfig"),
         ("vits", "VitsConfig"),
         ("vivit", "VivitConfig"),
+        ("vjepa2", "VJEPA2Config"),
         ("wav2vec2", "Wav2Vec2Config"),
         ("wav2vec2-bert", "Wav2Vec2BertConfig"),
         ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
@@ -750,6 +751,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
         ("vitpose_backbone", "ViTPoseBackbone"),
         ("vits", "VITS"),
         ("vivit", "ViViT"),
+        ("vjepa2", "VJEPA2Model"),
         ("wav2vec2", "Wav2Vec2"),
         ("wav2vec2-bert", "Wav2Vec2-BERT"),
         ("wav2vec2-conformer", "Wav2Vec2-Conformer"),
@@ -336,6 +336,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("vitdet", "VitDetModel"),
         ("vits", "VitsModel"),
         ("vivit", "VivitModel"),
+        ("vjepa2", "VJEPA2Model"),
         ("wav2vec2", "Wav2Vec2Model"),
         ("wav2vec2-bert", "Wav2Vec2BertModel"),
         ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
@@ -56,6 +56,7 @@ else:
         ("qwen2_vl", "Qwen2VLVideoProcessor"),
         ("smolvlm", "SmolVLMVideoProcessor"),
         ("video_llava", "VideoLlavaVideoProcessor"),
+        ("vjepa2", "VJEPA2VideoProcessor"),
     ]
 )
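These mapping entries are what let the `Auto*` classes resolve the new `"vjepa2"` model type. A minimal sketch of the effect, assuming a transformers build that includes this commit:

```python
from transformers import AutoConfig, AutoModel, AutoVideoProcessor

# "vjepa2" now resolves through the config/model/video-processor auto mappings
config = AutoConfig.for_model("vjepa2")  # -> VJEPA2Config with default values
model = AutoModel.from_config(config)    # -> randomly initialized VJEPA2Model
processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")  # -> VJEPA2VideoProcessor
print(type(config).__name__, type(model).__name__, type(processor).__name__)
```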
src/transformers/models/vjepa2/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vjepa2 import *
    from .modeling_vjepa2 import *
    from .video_processing_vjepa2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/vjepa2/configuration_vjepa2.py (new file, 146 lines)
@@ -0,0 +1,146 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VJEPA 2 model configuration"""

from ...configuration_utils import PretrainedConfig


class VJEPA2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VJEPA2Model`]. It is used to instantiate a
    VJEPA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VJEPA2
    [facebook/vjepa2-vitl-fpc64-256](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        crop_size (`int`, *optional*, defaults to 256):
            Input resolution of the model.
        frames_per_clip (`int`, *optional*, defaults to 64):
            The number of frames the model has been pretrained with. Does not impact inference.
        tubelet_size (`int`, *optional*, defaults to 2):
            The number of temporal frames used for a single raster, check paper for more information.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers.
        in_chans (`int`, *optional*, defaults to 3):
            The number of input channels.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Encoder.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            The number of hidden layers.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate per sample (when applied in the main path of residual layers).
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of the hidden size of the MLPs used in Encoder relative to the `hidden_size`.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for attentions.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        pred_hidden_size (`int`, *optional*, defaults to 384):
            Dimensionality of the predictor layers.
        pred_num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Predictor.
        pred_num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Predictor.
        pred_num_mask_tokens (`int`, *optional*, defaults to 10):
            Define the number of mask tokens to use in the Predictor.
        pred_zero_init_mask_tokens (`bool`, *optional*, defaults to `True`):
            Initialize the mask tokens in the predictor with 0.
        pred_mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of the hidden size of the MLPs used in Predictor relative to the `pred_hidden_size`.

    Example:

    ```python
    >>> from transformers import VJEPA2Config, VJEPA2Model

    >>> # Initializing a VJEPA2 vjepa2-vitl-fpc64-256 style configuration
    >>> configuration = VJEPA2Config()

    >>> # Initializing a model (with random weights) from the vjepa2-vitl-fpc64-256 style configuration
    >>> model = VJEPA2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vjepa2"

    def __init__(
        self,
        patch_size=16,
        crop_size=256,
        frames_per_clip=64,
        tubelet_size=2,
        hidden_size=1024,
        in_chans=3,
        num_attention_heads=16,
        num_hidden_layers=24,
        drop_path_rate=0.0,
        mlp_ratio=4.0,
        layer_norm_eps=1e-6,
        qkv_bias=True,
        attention_probs_dropout_prob=0.0,
        hidden_act="gelu",
        initializer_range=0.02,
        # predictor params
        pred_hidden_size=384,
        pred_num_attention_heads=12,
        pred_num_hidden_layers=12,
        pred_num_mask_tokens=10,
        pred_zero_init_mask_tokens=True,
        pred_mlp_ratio=4.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.crop_size = crop_size
        self.frames_per_clip = frames_per_clip
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.hidden_size = hidden_size
        self.in_chans = in_chans
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.drop_path_rate = drop_path_rate
        self.mlp_ratio = mlp_ratio
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.image_size = crop_size
        # predictor params
        self.pred_hidden_size = pred_hidden_size
        self.pred_num_attention_heads = pred_num_attention_heads
        self.pred_num_hidden_layers = pred_num_hidden_layers
        self.pred_num_mask_tokens = pred_num_mask_tokens
        self.pred_zero_init_mask_tokens = pred_zero_init_mask_tokens
        self.pred_mlp_ratio = pred_mlp_ratio


__all__ = ["VJEPA2Config"]
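As a quick orientation on how these defaults interact (my own worked example, not part of the file): the encoder sequence length follows from `frames_per_clip`, `tubelet_size`, `crop_size` and `patch_size`, matching `VJEPA2PatchEmbeddings3D.num_patches` in the modeling file further below.

```python
frames_per_clip, tubelet_size = 64, 2
crop_size, patch_size = 256, 16

# one token per (tubelet_size x patch_size x patch_size) spatio-temporal patch
num_patches = (frames_per_clip // tubelet_size) * (crop_size // patch_size) ** 2
print(num_patches)  # 32 * 16 * 16 = 8192 tokens, each of dimension hidden_size=1024
```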
src/transformers/models/vjepa2/convert_vjepa2_to_hf.py (new file, 346 lines)
@@ -0,0 +1,346 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import tempfile
from pathlib import Path

import numpy as np
import requests
import torch
from huggingface_hub import HfApi
from PIL import Image

from transformers import VJEPA2Config, VJEPA2Model, VJEPA2VideoProcessor
from transformers.models.vjepa2.modeling_vjepa2 import apply_masks


HUB_REPO = "https://github.com/facebookresearch/vjepa2"
HUB_SOURCE = "github"

HUB_MODELS = {
    "vit_large": "facebook/vjepa2-vitl-fpc64-256",
    "vit_huge": "facebook/vjepa2-vith-fpc64-256",
    "vit_giant": "facebook/vjepa2-vitg-fpc64-256",
    "vit_giant_384": "facebook/vjepa2-vitg-fpc64-384",
}

S3_MODELS = {
    "vit_large": "https://dl.fbaipublicfiles.com/vjepa2/vitl.pt",
    "vit_huge": "https://dl.fbaipublicfiles.com/vjepa2/vith.pt",
    "vit_giant": "https://dl.fbaipublicfiles.com/vjepa2/vitg.pt",
    "vit_giant_384": "https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt",
}

TOKEN = os.environ.get("HF_TOKEN", None)


def get_vjepa2_config(model_name):
    # size of the architecture
    if model_name == "vit_large":
        return VJEPA2Config(
            crop_size=256,
            frames_per_clip=64,
            hidden_size=1024,
            num_attention_heads=16,
            num_hidden_layers=24,
            mlp_ratio=4,
            pred_hidden_size=384,
            pred_num_attention_heads=12,
            pred_num_hidden_layers=12,
            pred_num_mask_tokens=10,
        )
    elif model_name == "vit_huge":
        return VJEPA2Config(
            crop_size=256,
            frames_per_clip=64,
            hidden_size=1280,
            num_attention_heads=16,
            num_hidden_layers=32,
            mlp_ratio=4,
            pred_hidden_size=384,
            pred_num_attention_heads=12,
            pred_num_hidden_layers=12,
            pred_num_mask_tokens=10,
        )
    elif model_name == "vit_giant":
        return VJEPA2Config(
            crop_size=256,
            frames_per_clip=64,
            hidden_size=1408,
            num_attention_heads=22,
            num_hidden_layers=40,
            mlp_ratio=48 / 11,
            pred_hidden_size=384,
            pred_num_attention_heads=12,
            pred_num_hidden_layers=12,
            pred_num_mask_tokens=10,
        )
    elif model_name == "vit_giant_384":
        return VJEPA2Config(
            crop_size=384,
            frames_per_clip=64,
            hidden_size=1408,
            num_attention_heads=22,
            num_hidden_layers=40,
            mlp_ratio=48 / 11,
            pred_hidden_size=384,
            pred_num_attention_heads=12,
            pred_num_hidden_layers=12,
            pred_num_mask_tokens=10,
        )
    else:
        raise ValueError("Model not supported")


def convert_encoder_keys(model_state_dict, og_encoder_state_dict, config):
    emb_dim = config.hidden_size
    for key, val in og_encoder_state_dict.copy().items():
        val = og_encoder_state_dict.pop(key)
        key = key.replace("module.backbone.", "")
        if key.startswith("blocks."):
            key = key.replace("blocks.", "encoder.layer.")
        if "attn." in key:
            key = key.replace("attn.", "attention.")
        if key == "pos_embed":
            key = "encoder.embeddings.position_embeddings"
        if "patch_embed." in key:
            key = key.replace("patch_embed.", "encoder.embeddings.patch_embeddings.")
        if key.startswith("norm."):
            key = key.replace("norm.", "encoder.layernorm.")
        if "qkv." in key:
            prefix, suffix = key.split("qkv")
            if "bias" in suffix:
                q_e, k_e, v_e = (
                    val[0:emb_dim],
                    val[emb_dim : emb_dim * 2],
                    val[emb_dim * 2 :],
                )
            else:
                q_e, k_e, v_e = (
                    val[0:emb_dim, :],
                    val[emb_dim : emb_dim * 2, :],
                    val[emb_dim * 2 :, :],
                )
            og_encoder_state_dict[prefix + "query" + suffix] = q_e
            og_encoder_state_dict[prefix + "key" + suffix] = k_e
            og_encoder_state_dict[prefix + "value" + suffix] = v_e
        else:
            og_encoder_state_dict[key] = val
    return og_encoder_state_dict


def convert_predictor_keys(model_state_dict, og_predictor_state_dict, config):
    emb_dim = config.pred_hidden_size
    if "predictor_pos_embed" in og_predictor_state_dict:
        del og_predictor_state_dict["predictor_pos_embed"]
    # update predictor weights
    mask_tokens = {}
    mask_token_keys_to_delete = []
    for key, val in og_predictor_state_dict.copy().items():
        val = og_predictor_state_dict.pop(key)
        key = key.replace("module.backbone.", "")
        if key.startswith("predictor_blocks."):
            key = key.replace("predictor_blocks.", "predictor.layer.")
        if "attn." in key:
            key = key.replace("attn.", "attention.")
        if key == "predictor_pos_embed":
            key = "predictor.embeddings.position_embeddings"
        if "predictor_embed." in key:
            key = key.replace("predictor_embed.", "predictor.embeddings.predictor_embeddings.")
        if "mask_tokens." in key:
            mask_tokens[key.split("mask_tokens.")[-1]] = val
            mask_token_keys_to_delete.append(key)
            # key = key.replace("mask_tokens.", "predictor.embeddings.mask_tokens.")
        if key.startswith("predictor_norm."):
            key = key.replace("predictor_norm.", "predictor.layernorm.")
        if key.startswith("predictor_proj."):
            key = key.replace("predictor_proj.", "predictor.proj.")
        if "qkv." in key:
            prefix, suffix = key.split("qkv")
            if "bias" in suffix:
                q_e, k_e, v_e = (
                    val[0:emb_dim],
                    val[emb_dim : emb_dim * 2],
                    val[emb_dim * 2 :],
                )
            else:
                q_e, k_e, v_e = (
                    val[0:emb_dim, :],
                    val[emb_dim : emb_dim * 2, :],
                    val[emb_dim * 2 :, :],
                )
            og_predictor_state_dict[prefix + "query" + suffix] = q_e
            og_predictor_state_dict[prefix + "key" + suffix] = k_e
            og_predictor_state_dict[prefix + "value" + suffix] = v_e
        else:
            og_predictor_state_dict[key] = val
    mask_tokens = torch.stack([mask_tokens[f"{i}"] for i in range(len(mask_tokens))], dim=0)
    for k in mask_token_keys_to_delete:
        del og_predictor_state_dict[k]
    og_predictor_state_dict["predictor.embeddings.mask_tokens"] = mask_tokens
    return og_predictor_state_dict


def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    return image


def upload_original_ckpts(model_name):
    hf_repo = HUB_MODELS[model_name]
    original_ckpt = S3_MODELS[model_name]
    print(f"Uploading original checkpoint for vjepa2 {model_name} to {hf_repo}/original/")
    with tempfile.NamedTemporaryFile() as fn:
        local_path = fn.name
        torch.hub.download_url_to_file(original_ckpt, local_path)
        api = HfApi()
        api.upload_file(
            repo_id=hf_repo,
            path_or_fileobj=local_path,
            path_in_repo="original/model.pth",
            repo_type="model",
            token=TOKEN,
        )
    print("Uploading complete")


@torch.no_grad()
def convert_and_test_vjepa2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our VJEPA2 structure.
    """
    config = get_vjepa2_config(model_name)

    # load original model from torch hub
    original_encoder, original_predictor = torch.hub.load(HUB_REPO, "vjepa2_" + model_name, source=HUB_SOURCE)
    original_encoder.eval()
    original_predictor.eval()
    original_preprocessor = torch.hub.load(
        HUB_REPO, "vjepa2_preprocessor", source=HUB_SOURCE, crop_size=config.crop_size
    )

    # load state_dict of original model, remove and rename some keys
    encoder_state_dict = original_encoder.state_dict()
    decoder_state_dict = original_predictor.state_dict()

    model = VJEPA2Model(config).eval()
    state_dict = model.state_dict()

    og_encoder_sd = convert_encoder_keys(state_dict, encoder_state_dict, config)
    og_predictor_sd = convert_predictor_keys(state_dict, decoder_state_dict, config)

    og_state_dict = og_encoder_sd
    og_state_dict.update(og_predictor_sd)
    model.load_state_dict(og_state_dict)

    # load image
    image = prepare_img()
    image = torch.Tensor(np.array(image)).unsqueeze(0).permute(0, 3, 1, 2)
    print("Input shape: ", image.shape)

    crop_size = config.crop_size
    processor = VJEPA2VideoProcessor(crop_size=crop_size)
    pr_out = processor(image, return_tensors="pt")
    pixel_values_videos = pr_out.pixel_values_videos
    # run original preprocessor
    original_pixel_values = original_preprocessor(image)
    assert original_pixel_values[0].permute(1, 0, 2, 3).shape == pixel_values_videos[0].shape
    assert torch.allclose(original_pixel_values[0].permute(1, 0, 2, 3), pixel_values_videos[0], atol=1e-3)

    with torch.no_grad():
        # reshape and move to gpu
        if pixel_values_videos.size(1) == 1:
            pixel_values_videos = pixel_values_videos.repeat(1, config.frames_per_clip, 1, 1, 1)
        # pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)  # B x C x T x H x W
        pixel_values_videos = pixel_values_videos.to(device="cuda", dtype=torch.float32)
        original_encoder = original_encoder.to(device="cuda", dtype=torch.float32)
        original_predictor = original_predictor.to(device="cuda", dtype=torch.float32)
        model = model.to(device="cuda", dtype=torch.float32)
        # forward
        original_encoder_outputs = original_encoder(pixel_values_videos.permute(0, 2, 1, 3, 4))
        B, N, _ = original_encoder_outputs.shape
        # test full mask
        context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
        predictor_mask = context_mask
        original_predictor_outputs = original_predictor(original_encoder_outputs, context_mask, predictor_mask)
        outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
        assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3)
        predictor_outputs = outputs.predictor_output
        assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3)
        # test partial mask
        window_size = 256
        mask = torch.arange(N, device=pixel_values_videos.device).unsqueeze(0)
        context_mask = [mask[:, :window_size].repeat((B, 1))]
        predictor_mask = [mask[:, window_size : window_size * 2].repeat((B, 1))]
        original_predictor_outputs = original_predictor(
            apply_masks(original_encoder_outputs, context_mask),
            context_mask,
            predictor_mask,
        )
        outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
        assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3)
        predictor_outputs = outputs.predictor_output
        assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3)

    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving image processor to {pytorch_dump_folder_path}")
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        name = HUB_MODELS[model_name]
        model.push_to_hub(name, private=True)
        processor.push_to_hub(name, private=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="vit_large",
        type=str,
        choices=[
            "vit_large",
            "vit_huge",
            "vit_giant",
            "vit_giant_384",
        ],
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the converted model to the 🤗 hub.",
    )
    parser.add_argument("--upload_original", action="store_true", help="upload the original checkpoint")

    args = parser.parse_args()
    convert_and_test_vjepa2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
    if args.upload_original:
        upload_original_ckpts(args.model_name)
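The key-renaming helpers above split the original fused `qkv` projection into the separate `query` / `key` / `value` linear layers used by `VJEPA2RopeAttention`. A tiny self-contained illustration of that row-wise slicing (toy sizes, my own example rather than part of the script):

```python
import torch

emb_dim = 4  # stands in for config.hidden_size
fused_qkv_weight = torch.arange(3 * emb_dim * emb_dim, dtype=torch.float32).reshape(3 * emb_dim, emb_dim)

# same split as convert_encoder_keys / convert_predictor_keys
q_w = fused_qkv_weight[0:emb_dim, :]
k_w = fused_qkv_weight[emb_dim : emb_dim * 2, :]
v_w = fused_qkv_weight[emb_dim * 2 :, :]

assert torch.equal(torch.cat([q_w, k_w, v_w], dim=0), fused_qkv_weight)
print(q_w.shape, k_w.shape, v_w.shape)  # each (emb_dim, emb_dim)
```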
src/transformers/models/vjepa2/modeling_vjepa2.py (new file, 903 lines)
@@ -0,0 +1,903 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_vjepa2 import VJEPA2Config


logger = logging.get_logger(__name__)


@dataclass
class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
    """
    VJEPA Predictor outputs that also contain the masked encoder outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        masked_hidden_state (`torch.FloatTensor`, *optional*):
            Returned when `context_mask` is provided, which is applied on VJEPA2Encoder outputs.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        target_hidden_state (`torch.FloatTensor`, *optional*):
            Returned when `target_mask` is provided, which is applied on VJEPA2Encoder outputs.
    """

    last_hidden_state: torch.FloatTensor
    masked_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    target_hidden_state: Optional[torch.FloatTensor] = None


@dataclass
class VJEPA2WithMaskedInputModelOutput(ModelOutput):
    """
    VJEPA outputs that also contain the masked encoder outputs.
    Optionally contains the predictor outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        masked_hidden_state (`torch.FloatTensor`, *optional*):
            Returned when `context_mask` is provided, which is applied on VJEPA2Encoder outputs.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
            Returns the output from the Predictor module.
    """

    last_hidden_state: torch.FloatTensor
    masked_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    predictor_output: Optional[VJEPA2WithMaskedInputPredictorOutput] = None

    def to_tuple(self):
        output = list(super().to_tuple())
        if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput):
            output[-1] = output[-1].to_tuple()
        return tuple(output)


class VJEPA2PatchEmbeddings3D(nn.Module):
    """
    Image to Patch Embedding
    """

    def __init__(
        self,
        config: VJEPA2Config,
        hidden_size: int = 1024,
    ):
        super().__init__()
        self.patch_size = config.patch_size
        self.tubelet_size = config.tubelet_size
        self.hidden_size = hidden_size

        self.proj = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=hidden_size,
            kernel_size=(config.tubelet_size, config.patch_size, config.patch_size),
            stride=(config.tubelet_size, config.patch_size, config.patch_size),
        )

    @staticmethod
    def num_patches(config):
        return (
            (config.frames_per_clip // config.tubelet_size)
            * (config.crop_size // config.patch_size)
            * (config.crop_size // config.patch_size)
        )

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2)
        return x


class VJEPA2Embeddings(nn.Module):
    """
    Construct mask token, position and patch embeddings.
    """

    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
        super().__init__()

        self.config = config
        self.hidden_size = hidden_size
        self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)

        self.num_patches = self.patch_embeddings.num_patches
        self.patch_size = config.patch_size

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        num_frames = pixel_values_videos.shape[1]

        # Swap `frames` and `channels` dims, the result is:
        # (batch_size, channels, num_frames, height, width)
        pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)

        # For some cases, if the input vision (image/video) consists of num_frames < tubelet_size,
        # then embedding lookup fails. In these cases, we duplicate the frames.
        if num_frames < self.config.tubelet_size:
            pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1)

        target_dtype = self.patch_embeddings.proj.weight.dtype
        pixel_values_videos = pixel_values_videos.to(dtype=target_dtype)
        embeddings = self.patch_embeddings(pixel_values_videos)

        return embeddings


# Adapted from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def rotate_queries_or_keys(x, pos):
    B, num_heads, N, D = x.size()

    # similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    # they are computing this every time. instead HF style is to compute the inv_freq once and store it
    # -- compute angle for each position
    omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
    omega /= D / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    freq = torch.einsum("..., f -> ... f", pos, omega)  # (..., N, D/2), outer product

    # -- build rotation matrix and apply
    emb_sin = freq.sin()  # (..., N, D/2)
    emb_cos = freq.cos()  # (..., N, D/2)

    emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
    emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)

    # --
    y = x.unflatten(-1, (-1, 2))
    y1, y2 = y.unbind(dim=-1)

    y = torch.stack((-y2, y1), dim=-1)
    y = y.flatten(-2)
    return (x * emb_cos) + (y * emb_sin)


class VJEPA2RopeAttention(nn.Module):
    def __init__(
        self,
        config: VJEPA2Config,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size {(hidden_size,)} is not a multiple of the number of attention "
                f"heads {num_attention_heads}."
            )

        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.proj = nn.Linear(hidden_size, hidden_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.dropout = nn.Dropout(self.dropout_prob)

        self.grid_size = self.config.crop_size // self.config.patch_size
        self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size

        self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))

        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def _get_frame_pos(self, ids):
        tokens_per_frame = int(self.grid_size * self.grid_size)
        return ids // tokens_per_frame

    def _get_height_pos(self, ids):
        # Remove frame component from ids
        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)
        ids = ids - tokens_per_frame * frame_ids
        # --
        tokens_per_row = self.grid_size
        return ids // tokens_per_row

    def get_position_ids(self, x, masks=None):
        device = x.device
        token_size = x.size(1)

        # Note: when masks is none, we use a 1d id instead of Bxnum_attention_heads mask,
        # as 1d vector is broadcasted to the correct shapes.
        if masks is not None:
            ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
        else:
            ids = torch.arange(token_size, device=device)
        # change to allow for extrapolation
        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)
        # --
        tokens_per_row = self.grid_size
        height_ids = self._get_height_pos(ids)
        # --
        # Remove frame component from ids (1st term) and height component (2nd term)
        width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
        return frame_ids, height_ids, width_ids

    def apply_rotary_embeddings(self, qk, pos_ids):
        d_mask, h_mask, w_mask = pos_ids
        s = 0
        qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
        s += self.d_dim
        qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
        s += self.h_dim
        qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
        s += self.w_dim
        # Combine rotated dimension
        if s < self.attention_head_size:
            qkr = qk[..., s:]
            qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
        else:
            qk = torch.cat([qkd, qkh, qkw], dim=-1)
        return qk

    def forward(
        self,
        hidden_states,
        position_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
        key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
        query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = self.proj(context_layer.reshape(new_context_layer_shape))

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


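A note on how `VJEPA2RopeAttention` budgets the per-head dimension across the three rotary axes (temporal depth, height, width). This is my own worked example for the default ViT-L configuration, not text from the PR:

```python
hidden_size, num_attention_heads = 1024, 16
attention_head_size = hidden_size // num_attention_heads  # 64

# d_dim == h_dim == w_dim == 2 * ((head_size // 3) // 2)
axis_dim = 2 * ((attention_head_size // 3) // 2)  # 2 * (21 // 2) = 20
rotated = 3 * axis_dim                            # 60 dims get depth/height/width RoPE
leftover = attention_head_size - rotated          # 4 dims pass through unrotated (the `qkr` branch)
print(axis_dim, rotated, leftover)                # 20 60 4
```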
# Adapted from transformers.models.beit.modeling_dinov2.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Adapted from transformers.models.beit.modeling_beit.BeitDropPath
class VJEPA2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


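One detail worth calling out in `drop_path` above: dividing the kept samples by `keep_prob` keeps the output unbiased in expectation, so inference (where the branch becomes a no-op) matches training statistics. A quick empirical check, written by me as an illustration rather than taken from the PR:

```python
import torch

x = torch.ones(10_000, 1, 1)
drop_prob, keep_prob = 0.2, 0.8

mask = (keep_prob + torch.rand_like(x)).floor_()  # 1 with probability keep_prob, else 0
out = x.div(keep_prob) * mask

print(out.mean().item())  # ~1.0: expectation matches the undropped input
```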
class VJEPA2MLP(nn.Module):
|
||||||
|
def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
|
||||||
|
super().__init__()
|
||||||
|
in_features = out_features = hidden_size
|
||||||
|
hidden_features = int(hidden_size * mlp_ratio)
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
|
||||||
|
self.activation = ACT2FN[config.hidden_act]
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
|
||||||
|
|
||||||
|
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_state = self.fc1(hidden_state)
|
||||||
|
hidden_state = self.activation(hidden_state)
|
||||||
|
hidden_state = self.fc2(hidden_state)
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|

class VJEPA2Layer(GradientCheckpointingLayer):
    """This corresponds to the Block class in the original implementation."""

    def __init__(
        self,
        config: VJEPA2Config,
        drop_path_rate: float = 0.0,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
        self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        # Self-Attention
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        self_attention_outputs = self.attention(
            hidden_states,
            position_mask=position_mask,  # position mask for context/target selection
            head_mask=head_mask,  # head mask is applied at F.scaled_dot_product_attention
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        hidden_states = self.drop_path(attention_output) + residual

        # MLP
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # Add self attentions if we output attention weights
        outputs = self_attention_outputs[1:]
        outputs = (hidden_states,) + outputs

        return outputs


class VJEPA2Encoder(nn.Module):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config

        self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
        drop_path_rates = [
            (config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0)
            for i in range(config.num_hidden_layers)
        ]
        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=drop_path_rates[i],
                    hidden_size=config.hidden_size,
                    num_attention_heads=config.num_attention_heads,
                    mlp_ratio=config.mlp_ratio,
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        pixel_values_videos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> BaseModelOutput:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        hidden_states = self.embeddings(pixel_values_videos)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, None, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = self.layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


def apply_masks(x, masks) -> torch.Tensor:
    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
    """
    all_x = []
    for m in masks:
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]

    return torch.cat(all_x, dim=0)

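
# Illustrative sketch of `apply_masks`: keep only the patch indices listed in each
# mask, and stack one masked copy per mask along the batch dimension.
def _example_apply_masks():
    import torch

    x = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(2, 6, 3)  # [B=2, N=6, D=3]
    keep_first_four = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])       # [B, K=4]
    out = apply_masks(x, [keep_first_four])
    # Only patches 0..3 of each sample remain; with a single mask the batch size
    # is unchanged.
    assert out.shape == (2, 4, 3)
    return out
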

class VJEPA2PredictorEmbeddings(nn.Module):
    """
    Construct mask token, position and patch embeddings.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()

        self.config = config
        self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
        self.num_mask_tokens = 0
        self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
        self.num_mask_tokens = config.pred_num_mask_tokens
        self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))

        self.patch_size = config.patch_size
        self.config = config

    @staticmethod
    def num_patches(config):
        if config.frames_per_clip > 1:
            return (
                (config.frames_per_clip // config.tubelet_size)
                * (config.crop_size // config.patch_size)
                * (config.crop_size // config.patch_size)
            )
        else:
            return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        context_mask: List[torch.Tensor],
        target_mask: List[torch.Tensor],
        mask_index: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        hidden_states : encoder outputs (context)
        context_mask: tokens of the context (outputs from the encoder)
        target_mask: tokens to predict
        mask_index: index of the target mask to choose (useful for multiclip?)
        """

        B = hidden_states.size(0)
        context = self.predictor_embeddings(hidden_states)

        # Make target tokens
        mask_index = mask_index % self.num_mask_tokens
        target = self.mask_tokens[mask_index]

        # Note: this is problematic if the config isn't initialized with the right frames_per_clip value,
        # e.g. for scenarios if we want to run predictor for more tokens than in the config.
        # target = target.repeat(B, self.num_patches(self.config), 1)
        # Remedy: use the provided target mask to get the max patch num
        max_patch_num = target_mask[0].max() + 1  # one extra to include the last patch
        target = target.repeat(B, max_patch_num, 1)
        target = apply_masks(target, target_mask)

        # Concatenate context & target tokens
        context = context.repeat(len(context_mask), 1, 1)
        embeddings = torch.cat([context, target], dim=1)

        # Positions of context & target tokens
        cm = torch.cat(context_mask, dim=0)
        tm = torch.cat(target_mask, dim=0)
        masks = torch.cat([cm, tm], dim=1)

        return embeddings, masks

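
# Worked example for `num_patches` above, assuming the dimensions of the
# facebook/vjepa2-vitl-fpc64-256 checkpoint used in the tests (frames_per_clip=64,
# tubelet_size=2, crop_size=256, patch_size=16 are assumptions here):
# (64 // 2) * (256 // 16) * (256 // 16) = 32 * 16 * 16 = 8192 patches, which matches
# the (1, 8192, 1024) last-hidden-state shape checked in the integration tests below.
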

class VJEPA2Predictor(nn.Module):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False
        self.embeddings = VJEPA2PredictorEmbeddings(config)
        drop_path_rates = [
            (
                config.drop_path_rate * i / (config.pred_num_hidden_layers - 1)
                if config.pred_num_hidden_layers > 1
                else 0.0
            )
            for i in range(config.pred_num_hidden_layers)
        ]
        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=drop_path_rates[i],
                    hidden_size=config.pred_hidden_size,
                    num_attention_heads=config.pred_num_attention_heads,
                    mlp_ratio=config.pred_mlp_ratio,
                )
                for i in range(config.pred_num_hidden_layers)
            ]
        )
        self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
        self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)

    def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
        position_masks = torch.gather(position_masks, dim=1, index=argsort)
        hidden_states = torch.gather(
            hidden_states,
            dim=1,
            index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
        )
        if head_mask is not None and head_mask[0] is not None:
            head_mask = head_mask.permute(1, 0, 2, 3, 4)
            argsort_4d = (
                argsort.unsqueeze(1)
                .unsqueeze(1)
                .expand(-1, head_mask.size(1), head_mask.size(2), -1)
                .unsqueeze(-1)
                .expand(-1, -1, -1, -1, head_mask.size(-1))
            )
            head_mask = torch.gather(head_mask, dim=3, index=argsort_4d)
            argsort_5d = (
                argsort.unsqueeze(1)
                .unsqueeze(1)
                .unsqueeze(1)
                .expand(-1, head_mask.size(1), head_mask.size(2), head_mask.size(3), -1)
            )
            head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
            head_mask = head_mask.permute(1, 0, 2, 3, 4)
        return hidden_states, position_masks, head_mask

    def unsort_tokens(self, hidden_states, argsort):
        reverse_argsort = torch.argsort(argsort, dim=1)
        hidden_states = torch.gather(
            hidden_states,
            dim=1,
            index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
        )
        return hidden_states

    @can_return_tuple
    def forward(
        self,
        encoder_hidden_states: torch.Tensor,
        context_mask: List[torch.Tensor],
        target_mask: List[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> BaseModelOutput:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # mask out the encoder hidden states
        # this is implemented here as in VJEPA training a separate encoder is used for target
        encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask)
        _, N_ctxt, D = encoder_hidden_states.shape
        hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask)

        # Put tokens in sorted order
        argsort = torch.argsort(position_masks, dim=1)  # [B, N]
        hidden_states, position_masks, head_mask = self.sort_tokens(hidden_states, position_masks, argsort, head_mask)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, position_masks, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.layernorm(hidden_states)
        # unsort and extract the predicted tokens
        hidden_states = self.unsort_tokens(hidden_states, argsort)
        hidden_states = hidden_states[:, N_ctxt:]
        # projection
        hidden_states = self.proj(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

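
# Illustrative sketch of the sort/unsort trick used by the predictor above: tokens
# are reordered by their position ids before running the transformer layers and
# restored afterwards, so gathering with the inverse argsort is an exact round trip.
def _example_sort_unsort_round_trip():
    import torch

    hidden = torch.randn(1, 5, 4)                # [B, N, D]
    positions = torch.tensor([[3, 0, 4, 1, 2]])  # arbitrary position ids
    argsort = torch.argsort(positions, dim=1)

    index = argsort.unsqueeze(-1).expand(-1, -1, hidden.size(-1))
    sorted_hidden = torch.gather(hidden, dim=1, index=index)

    reverse_argsort = torch.argsort(argsort, dim=1)
    reverse_index = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden.size(-1))
    restored = torch.gather(sorted_hidden, dim=1, index=reverse_index)

    assert torch.equal(restored, hidden)
    return restored
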

@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
    config_class = VJEPA2Config
    base_model_prefix = "vjepa2"
    main_input_name = "pixel_values_videos"
    supports_gradient_checkpointing = True
    _no_split_modules = ["VJEPA2Layer"]
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(
        self,
        module: Union[
            nn.Linear,
            nn.Conv2d,
            nn.LayerNorm,
            VJEPA2Embeddings,
            VJEPA2PredictorEmbeddings,
        ],
    ):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, VJEPA2PredictorEmbeddings):
            if not module.zero_init_mask_tokens:
                # note: the parameter is called `mask_tokens`; initialize it in place
                module.mask_tokens.data = nn.init.trunc_normal_(
                    module.mask_tokens.data.to(torch.float32),
                    mean=0.0,
                    std=self.config.initializer_range,
                ).to(module.mask_tokens.dtype)
            else:
                module.mask_tokens.data.zero_()


def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
    """
    Inputs:
        - head_mask: bsz x seq_length x seq_length | None
    Returns:
        - [num_hidden_layers x batch x num_heads x seq_length x seq_length] | [num_hidden_layers]
    """
    if head_mask is not None:
        head_mask = head_mask.unsqueeze(1).unsqueeze(0)
        head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
    else:
        head_mask = [None] * num_hidden_layers
    return head_mask

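
# Shape sketch for `_convert_head_mask_to_5d`: a head mask of shape
# [batch, seq_length, seq_length] becomes
# [num_hidden_layers, batch, 1, seq_length, seq_length], where the singleton
# dimension broadcasts over attention heads; passing None yields a list of
# `num_hidden_layers` Nones, one per layer.
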

@auto_docstring
class VJEPA2Model(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)
        self.config = config

        self.encoder = VJEPA2Encoder(config)
        self.predictor = VJEPA2Predictor(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
        return self.encoder.embeddings.patch_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        context_head_mask: Optional[torch.Tensor] = None,
        context_mask: Optional[List[torch.Tensor]] = None,
        target_head_mask: Optional[torch.Tensor] = None,
        target_mask: Optional[List[torch.Tensor]] = None,
        skip_predictor: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> VJEPA2WithMaskedInputModelOutput:
        r"""
        pixel_values_videos (`torch.Tensor` with shape `[batch size x num_frames x num_channels x height x width]`, required):
            The input video pixels, processed by VJEPA2VideoProcessor.
        context_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
            The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the context.
        target_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
            The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the target.
        context_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
            The mask position ids indicating which encoder output patches are going to be exposed to the predictor.
            By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B, 1), indicating full context
            available to the predictor.
        target_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
            The mask position ids indicating which encoder output patches are going to be used as a prediction target
            for the predictor. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B, 1), indicating
            that the predictor should predict all encoder patches.
        skip_predictor (bool):
            Flag to skip the predictor forward pass; useful if you only need the encoder outputs.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values_videos is None:
            raise ValueError("You have to specify pixel_values_videos")

        # Prepare head mask if needed
        context_head_mask = _convert_head_mask_to_5d(context_head_mask, self.config.num_hidden_layers)
        target_head_mask = _convert_head_mask_to_5d(target_head_mask, self.config.pred_num_hidden_layers)

        encoder_outputs: BaseModelOutput = self.encoder(
            pixel_values_videos=pixel_values_videos,
            head_mask=context_head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs.last_hidden_state

        if context_mask is None and target_mask is None:
            B = pixel_values_videos.size(0)
            N = sequence_output.size(1)  # ensure we are using dynamic patch size
            context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
            target_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]

        if not skip_predictor:
            predictor_outputs: BaseModelOutput = self.predictor(
                encoder_hidden_states=sequence_output,
                context_mask=context_mask,
                target_mask=target_mask,
                head_mask=target_head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            predictor_output = VJEPA2WithMaskedInputPredictorOutput(
                last_hidden_state=predictor_outputs.last_hidden_state,
                target_hidden_state=apply_masks(sequence_output, target_mask),
                hidden_states=predictor_outputs.hidden_states,
                attentions=predictor_outputs.attentions,
            )
        else:
            predictor_output = None

        encoder_output = VJEPA2WithMaskedInputModelOutput(
            last_hidden_state=sequence_output,
            masked_hidden_state=apply_masks(sequence_output, context_mask),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            predictor_output=predictor_output,
        )

        return encoder_output

    def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
        encoder_output = self.forward(pixel_values_videos)
        return encoder_output.last_hidden_state


__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"]

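
# Usage sketch for the model above (illustrative only). The checkpoint name is the
# one exercised in the tests below; the random uint8 clip of shape [T, C, H, W] is a
# stand-in for a real video and is assumed to be an accepted input format of the
# video processor.
if __name__ == "__main__":
    import torch

    from transformers import AutoVideoProcessor, VJEPA2Model

    repo = "facebook/vjepa2-vitl-fpc64-256"
    processor = AutoVideoProcessor.from_pretrained(repo)
    model = VJEPA2Model.from_pretrained(repo)

    video = torch.randint(0, 256, (16, 3, 256, 256), dtype=torch.uint8)  # [T, C, H, W]
    inputs = processor(video, return_tensors="pt")

    with torch.no_grad():
        # encoder-only features: [batch, num_patches, hidden_size]
        features = model.get_vision_features(inputs.pixel_values_videos)
        # full forward pass without the predictor branch
        outputs = model(inputs.pixel_values_videos, skip_predictor=True)
    print(features.shape, outputs.last_hidden_state.shape)
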

src/transformers/models/vjepa2/video_processing_vjepa2.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Video processor class for VJEPA2."""

from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
)
from ...processing_utils import Unpack, VideosKwargs
from ...utils import is_vision_available
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor


if is_vision_available():
    from ...image_utils import PILImageResampling


class VJEPA2VideoProcessorInitKwargs(VideosKwargs): ...


@requires(backends=("torchvision",))
class VJEPA2VideoProcessor(BaseVideoProcessor):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    size = {"shortest_edge": int(256 * 256 / 224)}
    crop_size = 256
    do_resize = True
    do_rescale = True
    do_center_crop = True
    do_normalize = True
    valid_kwargs = VJEPA2VideoProcessorInitKwargs
    model_input_names = ["pixel_values_videos"]

    def __init__(self, **kwargs: Unpack[VJEPA2VideoProcessorInitKwargs]):
        crop_size = kwargs.get("crop_size", 256)
        if not isinstance(crop_size, int):
            if not isinstance(crop_size, dict) or "height" not in crop_size:
                raise ValueError("crop_size must be an integer or a dictionary with a 'height' key")
            crop_size = crop_size["height"]
        resize_size = int(crop_size * 256 / 224)
        kwargs["size"] = {"shortest_edge": resize_size}
        super().__init__(**kwargs)


__all__ = ["VJEPA2VideoProcessor"]

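
# Sizing sketch for the processor above (illustrative): the shortest edge is resized
# to crop_size * 256 / 224 before the square center crop, mirroring the usual
# ImageNet-style resize-to-crop ratio, so the default crop_size of 256 implies a
# resize target of int(256 * 256 / 224) = 292.
def _example_vjepa2_resize_target(crop_size: int = 256) -> int:
    return int(crop_size * 256 / 224)  # 292 for the default 256 crop
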
@@ -166,6 +166,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "t5",
     "trocr",
     "vit",
+    "vjepa2",
     "xglm",
     "wav2vec2",
     # "xlnet",

tests/models/vjepa2/__init__.py (new file, 0 lines)

tests/models/vjepa2/test_modeling_vjepa2.py (new file, 345 lines)
@@ -0,0 +1,345 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch V-JEPA2 model."""

import unittest

import numpy as np

from transformers import VJEPA2Config
from transformers.testing_utils import (
    is_flaky,
    require_torch,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
from ...test_video_processing_common import (
    prepare_video_inputs,
)


if is_torch_available():
    import torch
    from torch import nn

    from transformers import VJEPA2Model


if is_vision_available():
    from PIL import Image

    from transformers import AutoVideoProcessor


VJEPA_HF_MODEL = "facebook/vjepa2-vitl-fpc64-256"


class VJEPA2ModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        image_size=16,
        patch_size=16,
        num_channels=3,
        hidden_size=32,
        num_hidden_layers=4,
        num_attention_heads=2,
        num_frames=2,
        mlp_ratio=1,
        pred_hidden_size=32,
        pred_num_attention_heads=2,
        pred_num_hidden_layers=2,
        pred_num_mask_tokens=10,
        is_training=False,
        attn_implementation="sdpa",
        mask_ratio=0.5,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_frames = num_frames
        self.mlp_ratio = mlp_ratio
        self.pred_hidden_size = pred_hidden_size
        self.pred_num_attention_heads = pred_num_attention_heads
        self.pred_num_hidden_layers = pred_num_hidden_layers
        self.pred_num_mask_tokens = pred_num_mask_tokens
        self.attn_implementation = attn_implementation
        self.is_training = is_training
        self.mask_ratio = mask_ratio

        num_patches = ((image_size // patch_size) ** 2) * (num_frames // 2)
        self.seq_length = num_patches
        self.num_masks = int(self.mask_ratio * self.seq_length)
        self.mask_length = num_patches

    def prepare_config_and_inputs(self):
        pixel_values_videos = floats_tensor(
            [
                self.batch_size,
                self.num_frames,
                self.num_channels,
                self.image_size,
                self.image_size,
            ]
        )

        config = self.get_config()

        return config, pixel_values_videos

    def get_config(self):
        return VJEPA2Config(
            crop_size=self.image_size,
            frames_per_clip=self.num_frames,
            hidden_size=self.hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            mlp_ratio=self.mlp_ratio,
            pred_hidden_size=self.pred_hidden_size,
            pred_num_attention_heads=self.pred_num_attention_heads,
            pred_num_hidden_layers=self.pred_num_hidden_layers,
            pred_num_mask_tokens=self.pred_num_mask_tokens,
        )

    def create_and_check_model(self, config, pixel_values_videos):
        model = VJEPA2Model(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values_videos)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.seq_length, self.hidden_size),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            pixel_values_videos,
        ) = config_and_inputs
        inputs_dict = {"pixel_values_videos": pixel_values_videos}
        return config, inputs_dict


@require_torch
class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as VJEPA2 does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    test_torch_exportable = True

    all_model_classes = (VJEPA2Model,) if is_torch_available() else ()

    fx_compatible = True

    pipeline_model_mapping = {}

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = VJEPA2ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=VJEPA2Config, has_text_modality=False, hidden_size=37)

    @is_flaky(max_attempts=3, description="`torch.nn.init.trunc_normal_` is flaky.")
    def test_initialization(self):
        super().test_initialization()

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="VJEPA2 does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="VJEPA2 does not support feedforward chunking yet")
    def test_feed_forward_chunking(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL)
        self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


def prepare_random_video(image_size=256):
    videos = prepare_video_inputs(
        batch_size=1,
        num_frames=16,
        num_channels=3,
        min_resolution=image_size,
        max_resolution=image_size,
        equal_resolution=True,
        return_tensors="torch",
    )
    return videos


@require_torch
@require_vision
class VJEPA2ModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_video_processor(self):
        return AutoVideoProcessor.from_pretrained(VJEPA_HF_MODEL) if is_vision_available() else None

    @slow
    def test_inference_image(self):
        model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)

        video_processor = self.default_video_processor
        image = prepare_img()
        inputs = video_processor(torch.Tensor(np.array(image)), return_tensors="pt").to(torch_device)
        pixel_values_videos = inputs.pixel_values_videos
        pixel_values_videos = pixel_values_videos.repeat(1, model.config.frames_per_clip, 1, 1, 1)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values_videos)

        # verify the last hidden states
        expected_shape = torch.Size((1, 8192, 1024))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
            device=torch_device,
        )
        torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)

    @slow
    def test_inference_video(self):
        model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)

        video_processor = self.default_video_processor
        video = prepare_random_video()
        inputs = video_processor(video, return_tensors="pt").to(torch_device)
        pixel_values_videos = inputs.pixel_values_videos

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values_videos)

        # verify the last hidden states
        expected_shape = torch.Size((1, 2048, 1024))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

    @slow
    def test_predictor_outputs(self):
        model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)

        video_processor = self.default_video_processor
        video = prepare_random_video()
        inputs = video_processor(video, return_tensors="pt").to(torch_device)
        pixel_values_videos = inputs.pixel_values_videos

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values_videos)

        # verify the last hidden states
        expected_shape = torch.Size((1, 2048, 1024))
        self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)

    @slow
    def test_predictor_full_mask(self):
        model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)

        video_processor = self.default_video_processor
        video = prepare_random_video()
        inputs = video_processor(video, return_tensors="pt").to(torch_device)
        pixel_values_videos = inputs.pixel_values_videos

        # forward pass
        with torch.no_grad():
            context_mask = [torch.arange(2048, device=pixel_values_videos.device).unsqueeze(0)]
            predictor_mask = context_mask
            outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)

        # verify the last hidden states
        expected_shape = torch.Size((1, 2048, 1024))
        self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)

    @slow
    def test_predictor_partial_mask(self):
        model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)

        video_processor = self.default_video_processor
        video = prepare_random_video()
        inputs = video_processor(video, return_tensors="pt").to(torch_device)
        pixel_values_videos = inputs.pixel_values_videos

        num_patches = 2048
        num_masks = 100
        # forward pass
        with torch.no_grad():
            pos_ids = torch.arange(num_patches, device=pixel_values_videos.device)
            context_mask = [pos_ids[0 : num_patches - num_masks].unsqueeze(0)]
            predictor_mask = [pos_ids[num_patches - num_masks :].unsqueeze(0)]
            outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)

        # verify the last hidden states
        expected_shape = torch.Size((1, num_masks, 1024))
        self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)

@@ -1301,6 +1301,7 @@ class ModelTesterMixin:
                "input_values",
                "inputs_embeds",
                "pixel_values",
+               "pixel_values_videos",
                "token_type_ids",
                "visual_feats",
                "visual_pos",