mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00

* adding model and conversion scripts * add imports to test vjepa conversion * fix imports and make conversion work * fix computation for short side * replace attention with library attention function * cleanup more attention classes * remove config overrides * add test cases, fix some of the failing ones * fix the model outputs * fix outputs of the model per review * fix too big model test case * fix styling __init__.py * fix initialization test * remove all asserts per review * update sorting unsorting logic as per feedback * remove is_video per review * remove another is_video segment * remove unwanted stuff * small fixes * add docstrings for the model * revert adding vjepa2 config here * update styling * add config docstrings (wip) * fix dpr issue * removed test failing issues * update styles * merge predictor configs into main config * remove processing code, add video processor * remove permute which is not necessary now * fix styles * updated vjepa2 to be in video_processing_auto * update comment for preprocessing * test integration test and fix the outputs * update test values, change test to look at repeated frames for a given image * add a simple video processing test * refactoring pixel_values_videos and upload ckpts to original * fix torch_fx test cases * remove unused config * add all config docstrings * add more integration tests * add basic doc * revert unwanted styling changes * working make fixup * Fix model_type in config * update attention implementation to fit new hf standards * fix the preprocessing logic, ensure it matches the original model * remove use_rope logic, cleanup * fix docstrings * Further cleanup, update doc * Fix model prefix * fix get_vision_features * VJEPA2Embeddings style refactor * nit, style comment * change modules default values * Only `str` activation in config * GradientCheckpointingLayer * fixup * fix conversion script * Remove return_dict * remove None return typehint * Refactor VJEPA2Layer, remove use_SiLU * Fix fx tests * dpr -> drop_path_rates * move *ModelOutput on top * format docs bit * update docs * update docs * update doc example * remove prune_heads from model * remove unused config params * refactor embed signature * Add vjepa to docs * Fix config docstring * update defaults * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Fix import * Min refactoring * Update HUB_SOURCE and HUB_REPO in conversion script * Add missing headers * VJEPA -> V-JEPA in docs * Add image to doc * fix style * fix init weights * change checkpoint name in modeling tests --------- Co-authored-by: Koustuv Sinha <koustuv.sinha@mail.mcgill.ca> Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co> Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Co-authored-by: Koustuv Sinha <koustuvsinha@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
346 lines
12 KiB
Python
346 lines
12 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch V-JEPA2 model."""
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from transformers import VJEPA2Config
|
|
from transformers.testing_utils import (
|
|
is_flaky,
|
|
require_torch,
|
|
require_vision,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
|
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
from ...test_video_processing_common import (
|
|
prepare_video_inputs,
|
|
)
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
from torch import nn
|
|
|
|
from transformers import VJEPA2Model
|
|
|
|
|
|
if is_vision_available():
|
|
from PIL import Image
|
|
|
|
from transformers import AutoVideoProcessor
|
|
|
|
VJEPA_HF_MODEL = "facebook/vjepa2-vitl-fpc64-256"
|
|
|
|
|
|
class VJEPA2ModelTester:
|
|
def __init__(
|
|
self,
|
|
parent,
|
|
batch_size=2,
|
|
image_size=16,
|
|
patch_size=16,
|
|
num_channels=3,
|
|
hidden_size=32,
|
|
num_hidden_layers=4,
|
|
num_attention_heads=2,
|
|
num_frames=2,
|
|
mlp_ratio=1,
|
|
pred_hidden_size=32,
|
|
pred_num_attention_heads=2,
|
|
pred_num_hidden_layers=2,
|
|
pred_num_mask_tokens=10,
|
|
is_training=False,
|
|
attn_implementation="sdpa",
|
|
mask_ratio=0.5,
|
|
):
|
|
self.parent = parent
|
|
self.batch_size = batch_size
|
|
self.image_size = image_size
|
|
self.patch_size = patch_size
|
|
self.num_channels = num_channels
|
|
self.hidden_size = hidden_size
|
|
self.num_hidden_layers = num_hidden_layers
|
|
self.num_attention_heads = num_attention_heads
|
|
self.num_frames = num_frames
|
|
self.mlp_ratio = mlp_ratio
|
|
self.pred_hidden_size = pred_hidden_size
|
|
self.pred_num_attention_heads = pred_num_attention_heads
|
|
self.pred_num_hidden_layers = pred_num_hidden_layers
|
|
self.pred_num_mask_tokens = pred_num_mask_tokens
|
|
self.attn_implementation = attn_implementation
|
|
self.is_training = is_training
|
|
self.mask_ratio = mask_ratio
|
|
|
|
num_patches = ((image_size // patch_size) ** 2) * (num_frames // 2)
|
|
self.seq_length = num_patches
|
|
self.num_masks = int(self.mask_ratio * self.seq_length)
|
|
self.mask_length = num_patches
|
|
|
|
def prepare_config_and_inputs(self):
|
|
pixel_values_videos = floats_tensor(
|
|
[
|
|
self.batch_size,
|
|
self.num_frames,
|
|
self.num_channels,
|
|
self.image_size,
|
|
self.image_size,
|
|
]
|
|
)
|
|
|
|
config = self.get_config()
|
|
|
|
return config, pixel_values_videos
|
|
|
|
def get_config(self):
|
|
return VJEPA2Config(
|
|
crop_size=self.image_size,
|
|
frames_per_clip=self.num_frames,
|
|
hidden_size=self.hidden_size,
|
|
num_attention_heads=self.num_attention_heads,
|
|
num_hidden_layers=self.num_hidden_layers,
|
|
mlp_ratio=self.mlp_ratio,
|
|
pred_hidden_size=self.pred_hidden_size,
|
|
pred_num_attention_heads=self.pred_num_attention_heads,
|
|
pred_num_hidden_layers=self.pred_num_hidden_layers,
|
|
pred_num_mask_tokens=self.pred_num_mask_tokens,
|
|
)
|
|
|
|
def create_and_check_model(self, config, pixel_values_videos):
|
|
model = VJEPA2Model(config=config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
result = model(pixel_values_videos)
|
|
self.parent.assertEqual(
|
|
result.last_hidden_state.shape,
|
|
(self.batch_size, self.seq_length, self.hidden_size),
|
|
)
|
|
|
|
def prepare_config_and_inputs_for_common(self):
|
|
config_and_inputs = self.prepare_config_and_inputs()
|
|
(
|
|
config,
|
|
pixel_values_videos,
|
|
) = config_and_inputs
|
|
inputs_dict = {"pixel_values_videos": pixel_values_videos}
|
|
return config, inputs_dict
|
|
|
|
|
|
@require_torch
|
|
class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|
"""
|
|
Here we also overwrite some of the tests of test_modeling_common.py, as VJEPA2 does not use input_ids, inputs_embeds,
|
|
attention_mask and seq_length.
|
|
"""
|
|
|
|
test_torch_exportable = True
|
|
|
|
all_model_classes = (VJEPA2Model,) if is_torch_available() else ()
|
|
|
|
fx_compatible = True
|
|
|
|
pipeline_model_mapping = {}
|
|
|
|
test_pruning = False
|
|
test_resize_embeddings = False
|
|
test_head_masking = False
|
|
|
|
def setUp(self):
|
|
self.model_tester = VJEPA2ModelTester(self)
|
|
self.config_tester = ConfigTester(self, config_class=VJEPA2Config, has_text_modality=False, hidden_size=37)
|
|
|
|
@is_flaky(max_attempts=3, description="`torch.nn.init.trunc_normal_` is flaky.")
|
|
def test_initialization(self):
|
|
super().test_initialization()
|
|
|
|
def test_config(self):
|
|
self.config_tester.run_common_tests()
|
|
|
|
@unittest.skip(reason="VJEPA2 does not use inputs_embeds")
|
|
def test_inputs_embeds(self):
|
|
pass
|
|
|
|
@unittest.skip(
|
|
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
|
)
|
|
def test_training_gradient_checkpointing(self):
|
|
pass
|
|
|
|
@unittest.skip(
|
|
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
|
)
|
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
|
pass
|
|
|
|
@unittest.skip(
|
|
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
|
)
|
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
|
pass
|
|
|
|
def test_model_get_set_embeddings(self):
|
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
|
|
for model_class in self.all_model_classes:
|
|
model = model_class(config)
|
|
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
|
x = model.get_output_embeddings()
|
|
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
|
|
|
def test_model(self):
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
|
|
@unittest.skip(reason="VJEPA2 does not support feedforward chunking yet")
|
|
def test_feed_forward_chunking(self):
|
|
pass
|
|
|
|
@slow
|
|
def test_model_from_pretrained(self):
|
|
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL)
|
|
self.assertIsNotNone(model)
|
|
|
|
|
|
# We will verify our results on an image of cute cats
|
|
def prepare_img():
|
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
|
return image
|
|
|
|
|
|
def prepare_random_video(image_size=256):
|
|
videos = prepare_video_inputs(
|
|
batch_size=1,
|
|
num_frames=16,
|
|
num_channels=3,
|
|
min_resolution=image_size,
|
|
max_resolution=image_size,
|
|
equal_resolution=True,
|
|
return_tensors="torch",
|
|
)
|
|
return videos
|
|
|
|
|
|
@require_torch
|
|
@require_vision
|
|
class VJEPA2ModelIntegrationTest(unittest.TestCase):
|
|
@cached_property
|
|
def default_video_processor(self):
|
|
return AutoVideoProcessor.from_pretrained(VJEPA_HF_MODEL) if is_vision_available() else None
|
|
|
|
@slow
|
|
def test_inference_image(self):
|
|
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
|
|
|
|
video_processor = self.default_video_processor
|
|
image = prepare_img()
|
|
inputs = video_processor(torch.Tensor(np.array(image)), return_tensors="pt").to(torch_device)
|
|
pixel_values_videos = inputs.pixel_values_videos
|
|
pixel_values_videos = pixel_values_videos.repeat(1, model.config.frames_per_clip, 1, 1, 1)
|
|
|
|
# forward pass
|
|
with torch.no_grad():
|
|
outputs = model(pixel_values_videos)
|
|
|
|
# verify the last hidden states
|
|
expected_shape = torch.Size((1, 8192, 1024))
|
|
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
|
|
|
expected_slice = torch.tensor(
|
|
[[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
|
|
device=torch_device,
|
|
)
|
|
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
|
|
|
|
@slow
|
|
def test_inference_video(self):
|
|
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
|
|
|
|
video_processor = self.default_video_processor
|
|
video = prepare_random_video()
|
|
inputs = video_processor(video, return_tensors="pt").to(torch_device)
|
|
pixel_values_videos = inputs.pixel_values_videos
|
|
|
|
# forward pass
|
|
with torch.no_grad():
|
|
outputs = model(pixel_values_videos)
|
|
|
|
# verify the last hidden states
|
|
expected_shape = torch.Size((1, 2048, 1024))
|
|
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
|
|
|
@slow
|
|
def test_predictor_outputs(self):
|
|
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
|
|
|
|
video_processor = self.default_video_processor
|
|
video = prepare_random_video()
|
|
inputs = video_processor(video, return_tensors="pt").to(torch_device)
|
|
pixel_values_videos = inputs.pixel_values_videos
|
|
|
|
# forward pass
|
|
with torch.no_grad():
|
|
outputs = model(pixel_values_videos)
|
|
|
|
# verify the last hidden states
|
|
expected_shape = torch.Size((1, 2048, 1024))
|
|
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
|
|
|
|
@slow
|
|
def test_predictor_full_mask(self):
|
|
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
|
|
|
|
video_processor = self.default_video_processor
|
|
video = prepare_random_video()
|
|
inputs = video_processor(video, return_tensors="pt").to(torch_device)
|
|
pixel_values_videos = inputs.pixel_values_videos
|
|
|
|
# forward pass
|
|
with torch.no_grad():
|
|
context_mask = [torch.arange(2048, device=pixel_values_videos.device).unsqueeze(0)]
|
|
predictor_mask = context_mask
|
|
outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
|
|
|
|
# verify the last hidden states
|
|
expected_shape = torch.Size((1, 2048, 1024))
|
|
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
|
|
|
|
@slow
|
|
def test_predictor_partial_mask(self):
|
|
model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
|
|
|
|
video_processor = self.default_video_processor
|
|
video = prepare_random_video()
|
|
inputs = video_processor(video, return_tensors="pt").to(torch_device)
|
|
pixel_values_videos = inputs.pixel_values_videos
|
|
|
|
num_patches = 2048
|
|
num_masks = 100
|
|
# forward pass
|
|
with torch.no_grad():
|
|
pos_ids = torch.arange(num_patches, device=pixel_values_videos.device)
|
|
context_mask = [pos_ids[0 : num_patches - num_masks].unsqueeze(0)]
|
|
predictor_mask = [pos_ids[num_patches - num_masks :].unsqueeze(0)]
|
|
outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
|
|
|
|
# verify the last hidden states
|
|
expected_shape = torch.Size((1, num_masks, 1024))
|
|
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
|