Add V-JEPA for video classification model (#38788)
* adding model and conversion scripts
* add imports to test vjepa conversion
* fix imports and make conversion work
* fix computation for short side
* replace attention with library attention function
* cleanup more attention classes
* remove config overrides
* add test cases, fix some of the failing ones
* fix the model outputs
* fix outputs of the model per review
* fix too big model test case
* fix styling __init__.py
* fix initialization test
* remove all asserts per review
* update sorting unsorting logic as per feedback
* remove is_video per review
* remove another is_video segment
* remove unwanted stuff
* small fixes
* add docstrings for the model
* revert adding vjepa2 config here
* update styling
* add config docstrings (wip)
* fix dpr issue
* removed test failing issues
* update styles
* merge predictor configs into main config
* remove processing code, add video processor
* remove permute which is not necessary now
* fix styles
* updated vjepa2 to be in video_processing_auto
* update comment for preprocessing
* test integration test and fix the outputs
* update test values, change test to look at repeated frames for a given image
* add a simple video processing test
* refactoring pixel_values_videos and upload ckpts to original
* fix torch_fx test cases
* remove unused config
* add all config docstrings
* add more integration tests
* add basic doc
* revert unwanted styling changes
* working make fixup
* Fix model_type in config
* Add ForVideoClassification model
* update attention implementation to fit new hf standards
* fix the preprocessing logic, ensure it matches the original model
* remove use_rope logic, cleanup
* fix docstrings
* Further cleanup, update doc
* Fix model prefix
* fix get_vision_features
* VJEPA2Embeddings style refactor
* nit, style comment
* change modules default values
* Only `str` activation in config
* GradientCheckpointingLayer
* fixup
* fix conversion script
* Remove return_dict
* remove None return typehint
* Refactor VJEPA2Layer, remove use_SiLU
* Fix fx tests
* dpr -> drop_path_rates
* move *ModelOutput on top
* format docs bit
* update docs
* update docs
* update doc example
* remove prune_heads from model
* remove unused config params
* refactor embed signature
* Add vjepa to docs
* Fix config docstring
* attention head
* update defaults
* Update docs/source/en/model_doc/vjepa2.md (Co-authored-by: Pedro Cuenca <pedro@huggingface.co>)
* Update docs/source/en/model_doc/vjepa2.md (Co-authored-by: Pedro Cuenca <pedro@huggingface.co>)
* Fix import
* Min refactoring
* Update HUB_SOURCE and HUB_REPO in conversion script
* Add missing headers
* VJEPA -> V-JEPA in docs
* Add image to doc
* fix style
* fix init weights
* change checkpoint name in modeling tests
* Initial cls head setup
* remove rop attention from head (not needed)
* remove swigluffn - not needed
* Add siglip layer
* Replace with siglip layer
* Rename Siglip - VJEPA2
* remove unused modules
* remove siglip mlp
* nit
* remove MLP
* Refactor head cross attention
* refactor VJEPA2HeadCrossAttentionLayer
* nit renaming
* fixup
* remove commented code
* Add cls head params to config
* depth from config
* move pooler + classifier to the model
* Update for cls model signature
* move layers, rename a bit
* fix docs
* update weights init
* remove typehint for init
* add to auto-mapping
* enable tests
* Add conversion script
* fixup
* add to docs
* fix docs
* nit
* refactor for mapping
* clean
* Add integration test
* Fixing multi gpu test
* update not-split-modules
* update video cls test tolerance
* Increase test_inference_image tolerance
* Update no-split modules for multi gpu
* Apply suggestions from code review
* fixing multi-gpu
* fix docstring
* Add cls snippet to docs
* Update checkpoint
This commit is contained in:
parent 2ff964bcb4
commit 9bec2654ed
@@ -38,7 +38,7 @@ This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yoni

## Usage example

The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class.
The snippet below shows how to load the V-JEPA 2 model for feature extraction using the `AutoModel` class.

```py
import torch
@@ -68,6 +68,43 @@ encoder_outputs = outputs.last_hidden_state
predictor_outputs = outputs.predictor_output.last_hidden_state
```

V-JEPA 2 can also be finetuned for video classification. The following snippet shows how to use a model fine-tuned on Something-Something-V2 for video classification.

```python
import torch
import numpy as np

from torchcodec.decoders import VideoDecoder
from transformers import AutoVideoProcessor, AutoModelForVideoClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and video preprocessor
hf_repo = "facebook/vjepa2-vitl-fpc16-256-ssv2"

model = AutoModelForVideoClassification.from_pretrained(hf_repo).to(device)
processor = AutoVideoProcessor.from_pretrained(hf_repo)

# To load a video, sample the number of frames according to the model.
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, model.config.frames_per_clip, 8)  # you can define a more complex sampling strategy
video = vr.get_frames_at(indices=frame_idx).data  # frames x channels x height x width

# Preprocess and run inference
inputs = processor(video, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

print("Top 5 predicted class names:")
top5_indices = logits.topk(5).indices[0]
top5_probs = torch.softmax(logits, dim=-1).topk(5).values[0]
for idx, prob in zip(top5_indices, top5_probs):
    text_label = model.config.id2label[idx.item()]
    print(f" - {text_label}: {prob:.2f}")
```

## VJEPA2Config

[[autodoc]] VJEPA2Config

@@ -77,6 +114,11 @@ predictor_outputs = outputs.predictor_output.last_hidden_state

[[autodoc]] VJEPA2Model
    - forward

## VJEPA2ForVideoClassification

[[autodoc]] VJEPA2ForVideoClassification
    - forward

## VJEPA2VideoProcessor

[[autodoc]] VJEPA2VideoProcessor
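
For context, a minimal end-to-end sketch of the feature-extraction path referenced above. The checkpoint name and the frame-sampling choice are assumptions here; only `outputs.last_hidden_state` and `outputs.predictor_output.last_hidden_state` come from the snippet shown in the hunk.

```python
import torch
import numpy as np

from torchcodec.decoders import VideoDecoder
from transformers import AutoModel, AutoVideoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed base checkpoint for feature extraction (the diff only shows the tail of this snippet)
hf_repo = "facebook/vjepa2-vitl-fpc64-256"
model = AutoModel.from_pretrained(hf_repo).to(device)
processor = AutoVideoProcessor.from_pretrained(hf_repo)

video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, model.config.frames_per_clip)  # assumption: take the first frames_per_clip frames
video = vr.get_frames_at(indices=frame_idx).data

inputs = processor(video, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# These two attributes appear verbatim in the hunk above.
encoder_outputs = outputs.last_hidden_state
predictor_outputs = outputs.predictor_output.last_hidden_state
```
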
@@ -150,6 +150,7 @@ LOSS_MAPPING = {
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForImageClassification": ForSequenceClassificationLoss,
    "ForVideoClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForSegmentation": ForSegmentationLoss,
    "ForObjectDetection": ForObjectDetectionLoss,
@@ -844,6 +844,7 @@ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("timesformer", "TimesformerForVideoClassification"),
        ("videomae", "VideoMAEForVideoClassification"),
        ("vivit", "VivitForVideoClassification"),
        ("vjepa2", "VJEPA2ForVideoClassification"),
    ]
)
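
Registering the `("vjepa2", "VJEPA2ForVideoClassification")` pair is what lets the `AutoModelForVideoClassification` factory resolve a V-JEPA 2 checkpoint to the new class. A small check, reusing the checkpoint name from the docs snippet above:

```python
from transformers import AutoModelForVideoClassification

model = AutoModelForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
# The auto class dispatches on config.model_type == "vjepa2"
print(type(model).__name__)  # VJEPA2ForVideoClassification
```
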
@@ -60,6 +60,10 @@ class VJEPA2Config(PretrainedConfig):
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for attentions.
        num_pooler_layers (`int`, *optional*, defaults to 3):
            The number of self-attention layers in the pooler.
        pred_hidden_size (`int`, *optional*, defaults to 384):
            Dimensionality of the predictor layers
        pred_num_attention_heads (`int`, *optional*, defaults to 12):
@@ -107,6 +111,8 @@ class VJEPA2Config(PretrainedConfig):
        attention_probs_dropout_prob=0.0,
        hidden_act="gelu",
        initializer_range=0.02,
        attention_dropout=0.0,
        num_pooler_layers=3,
        # predictor params
        pred_hidden_size=384,
        pred_num_attention_heads=12,
@@ -134,6 +140,8 @@ class VJEPA2Config(PretrainedConfig):
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.image_size = crop_size
        self.attention_dropout = attention_dropout
        self.num_pooler_layers = num_pooler_layers
        # predictor params
        self.pred_hidden_size = pred_hidden_size
        self.pred_num_attention_heads = pred_num_attention_heads
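
The new fields are plain constructor arguments, so the classification head can be sized from the config alone. A small sketch; the parameter names are taken from the hunk above, everything else uses the defaults:

```python
from transformers import VJEPA2Config

config = VJEPA2Config(
    num_pooler_layers=3,     # depth of the attentive pooler's self-attention stack
    attention_dropout=0.0,   # dropout used inside the pooler attention layers
    initializer_range=0.02,  # std of the truncated-normal init used in _init_weights
)
print(config.num_pooler_layers, config.attention_dropout, config.initializer_range)
```
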
@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import re

import numpy as np
import torch
from decord import VideoReader
from huggingface_hub import HfApi, hf_hub_download

from transformers import VJEPA2ForVideoClassification, VJEPA2VideoProcessor


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_video():
    path = hf_hub_download(
        repo_id="nateraw/kinetics-mini",
        filename="val/bowling/-WH-lxmGJVY_000005_000015.mp4",
        repo_type="dataset",
    )
    video_reader = VideoReader(path)
    return video_reader


CLASSIFIERS = {
    # Something-Something-v2 dataset
    "vjepa2-vitl-fpc16-256-ssv2": {
        "base_model": "facebook/vjepa2-vitl-fpc64-256",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitl-16x2x3.pt",
        "num_labels": 174,
        "frames_per_clip": 16,
        "dataset": "something-something-v2",
        "result": (145, 0.30867, "Stuffing [something] into [something]"),
    },
    "vjepa2-vitg-fpc64-384-ssv2": {
        "base_model": "facebook/vjepa2-vitg-fpc64-384",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt",
        "frames_per_clip": 64,
        "num_labels": 174,
        "dataset": "something-something-v2",
        "result": (112, 0.26408, "Putting [something] onto [something]"),
    },
    # Diving48 dataset
    "vjepa2-vitl-fpc32-256-diving48": {
        "base_model": "facebook/vjepa2-vitl-fpc64-256",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitl-256.pt",
        "num_labels": 48,
        "frames_per_clip": 32,
        "dataset": "diving48",
        "result": (35, 0.32875, "['Inward', '35som', 'NoTwis', 'TUCK']"),
    },
    "vjepa2-vitg-fpc32-384-diving48": {
        "base_model": "facebook/vjepa2-vitg-fpc64-384",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitg-384-32x4x3.pt",
        "frames_per_clip": 32,
        "num_labels": 48,
        "dataset": "diving48",
        "result": (22, 0.35351, "['Forward', '25som', '2Twis', 'PIKE']"),
    },
}

# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"module.pooler.query_tokens": r"pooler.query_tokens",
    r"module.pooler.cross_attention_block.norm(\d+).": r"pooler.cross_attention_layer.layer_norm\1.",
    r"module.pooler.cross_attention_block.xattn.(q|k|v).": r"pooler.cross_attention_layer.cross_attn.\1_proj.",
    r"module.pooler.cross_attention_block.mlp.fc(\d+).": r"pooler.cross_attention_layer.mlp.fc\1.",
    r"module.pooler.blocks.(\d+).norm(\d+).": r"pooler.self_attention_layers.\1.layer_norm\2.",
    r"module.pooler.blocks.(\d+).attn.(q|k|v).": r"pooler.self_attention_layers.\1.self_attn.\2_proj.",
    r"module.pooler.blocks.(\d+).attn.proj.": r"pooler.self_attention_layers.\1.self_attn.out_proj.",
    r"module.pooler.blocks.(\d+).mlp.fc(\d+).": r"pooler.self_attention_layers.\1.mlp.fc\2.",
    r"module.linear.": r"classifier.",
}
# fmt: on


def get_id2label_mapping(dataset_name: str) -> dict[int, str]:
    path = hf_hub_download(
        repo_id="huggingface/label-files",
        filename=f"{dataset_name}-id2label.json",
        repo_type="dataset",
    )
    with open(path, "r") as f:
        id2label = json.load(f)
    id2label = {int(k): v for k, v in id2label.items()}
    return id2label


def split_qkv(state_dict):
    state_dict = state_dict.copy()
    keys = list(state_dict.keys())
    for key in keys:
        if ".qkv." in key:
            tensor = state_dict.pop(key)
            q, k, v = torch.chunk(tensor, 3, dim=0)
            state_dict[key.replace(".qkv.", ".q.")] = q
            state_dict[key.replace(".qkv.", ".k.")] = k
            state_dict[key.replace(".qkv.", ".v.")] = v
        elif ".kv." in key:
            tensor = state_dict.pop(key)
            k, v = torch.chunk(tensor, 2, dim=0)
            state_dict[key.replace(".kv.", ".k.")] = k
            state_dict[key.replace(".kv.", ".v.")] = v

    return state_dict


def convert_old_keys_to_new_keys(state_dict):
    """
    This function should be applied only once, on the concatenated keys to efficiently rename using
    the key mappings.
    """
    output_dict = {}
    old_text = "\n".join(state_dict)
    new_text = old_text
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        if replacement is None:
            new_text = re.sub(pattern, "", new_text)  # an empty line
            continue
        new_text = re.sub(pattern, replacement, new_text)
    output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
    return output_dict


def main(args: argparse.Namespace):
    model_params = CLASSIFIERS[args.model_name]
    id2label = get_id2label_mapping(model_params["dataset"])

    if not len(id2label) == model_params["num_labels"]:
        raise ValueError(
            f"Number of labels in id2label mapping ({len(id2label)}) does not "
            f"match number of labels in model ({model_params['num_labels']})"
        )

    model = VJEPA2ForVideoClassification.from_pretrained(
        model_params["base_model"],
        num_labels=model_params["num_labels"],
        id2label=id2label,
        frames_per_clip=model_params["frames_per_clip"],
    )
    processor = VJEPA2VideoProcessor.from_pretrained(model_params["base_model"])

    # load and convert classifier checkpoint
    checkpoint = torch.hub.load_state_dict_from_url(model_params["checkpoint"])
    state_dict = checkpoint["classifiers"][0]

    state_dict_qkv_split = split_qkv(state_dict)
    key_mapping = convert_old_keys_to_new_keys(state_dict_qkv_split.keys())
    converted_state_dict2 = {key_mapping[k]: v for k, v in state_dict_qkv_split.items()}

    result = model.load_state_dict(converted_state_dict2, strict=False)
    if result.unexpected_keys:
        raise ValueError(f"Error loading state dict: {result.unexpected_keys}")

    if not args.skip_verification:
        # get inputs
        video_reader = get_video()
        frame_indexes = np.arange(0, 128, 128 / model_params["frames_per_clip"])
        video = video_reader.get_batch(frame_indexes).asnumpy()
        inputs = processor(video, return_tensors="pt").to(device)

        # run model
        model.to(device).eval()
        with torch.no_grad():
            outputs = model(**inputs)

        # compare results
        probs = torch.softmax(outputs.logits, dim=-1)
        top_prob, top_idx = probs.topk(1)
        top_prob, top_idx = top_prob.item(), top_idx.item()
        label = id2label[top_idx]
        expected_id, expected_prob, expected_label = model_params["result"]

        if not top_idx == expected_id:
            raise ValueError(f"Expected id {expected_id} but got {top_idx}")
        if not label == expected_label:
            raise ValueError(f"Expected label {expected_label} but got {label}")
        if not np.isclose(top_prob, expected_prob, atol=1e-3):
            raise ValueError(f"Expected prob {expected_prob} but got {top_prob}")
        print("Verification passed")

    output_dir = os.path.join(args.base_dir, args.model_name)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    if args.push_to_hub:
        api = HfApi()
        repo_id = f"{args.repo_org}/{args.model_name}"
        if not api.repo_exists(repo_id):
            api.create_repo(repo_id, repo_type="model")
        api.upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--base_dir", type=str, default="converted_models/")
    parser.add_argument("--repo_org", type=str, default="qubvel-hf")
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--skip_verification", action="store_true")
    args = parser.parse_args()

    main(args)
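
The regex table above rewrites the original V-JEPA 2 probe keys into the HF module names after `split_qkv` has broken the fused projections apart. A quick demo on two made-up keys (the key strings here are illustrative, not taken from a real checkpoint):

```python
import re

# Two of the patterns used above, applied to hypothetical original keys.
mapping = {
    r"module.pooler.blocks.(\d+).attn.(q|k|v).": r"pooler.self_attention_layers.\1.self_attn.\2_proj.",
    r"module.linear.": r"classifier.",
}

for old_key in ["module.pooler.blocks.0.attn.q.weight", "module.linear.bias"]:
    new_key = old_key
    for pattern, replacement in mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{old_key} -> {new_key}")
# module.pooler.blocks.0.attn.q.weight -> pooler.self_attention_layers.0.self_attn.q_proj.weight
# module.linear.bias -> classifier.bias
```
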
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

import torch
@@ -19,7 +20,7 @@ from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, dataclass
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_vjepa2 import VJEPA2Config
@@ -536,17 +537,21 @@ class VJEPA2Encoder(nn.Module):
        )


def apply_masks(x, masks) -> torch.Tensor:
def apply_masks(tensor: torch.Tensor, masks: List[torch.Tensor]) -> torch.Tensor:
    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
    Args:
        tensor (`torch.Tensor`):
            Tensor of shape [batch_size, num_patches, feature_dim]
        masks (`List[torch.Tensor]`):
            List of tensors of shape [batch_size, num_patches] containing indices of patches to keep
    """
    all_x = []
    for m in masks:
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    all_masked_tensors = []
    for mask in masks:
        mask = mask.to(tensor.device)
        mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1))
        all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)]

    return torch.cat(all_x, dim=0)
    return torch.cat(all_masked_tensors, dim=0)


class VJEPA2PredictorEmbeddings(nn.Module):
@@ -649,13 +654,18 @@ class VJEPA2Predictor(nn.Module):
        self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)

    def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
        # gather position masks
        argsort = argsort.to(position_masks.device)
        position_masks = torch.gather(position_masks, dim=1, index=argsort)
        hidden_states = torch.gather(
            hidden_states,
            dim=1,
            index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
        )

        # gather hidden states
        argsort = argsort.to(hidden_states.device)
        hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort)

        # gather head mask
        if head_mask is not None and head_mask[0] is not None:
            argsort = argsort.to(head_mask.device)
            head_mask = head_mask.permute(1, 0, 2, 3, 4)
            argsort_4d = (
                argsort.unsqueeze(1)
@@ -673,15 +683,14 @@ class VJEPA2Predictor(nn.Module):
            )
            head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
            head_mask = head_mask.permute(1, 0, 2, 3, 4)

        return hidden_states, position_masks, head_mask

    def unsort_tokens(self, hidden_states, argsort):
        argsort = argsort.to(hidden_states.device)
        reverse_argsort = torch.argsort(argsort, dim=1)
        hidden_states = torch.gather(
            hidden_states,
            dim=1,
            index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
        )
        reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort)
        return hidden_states

    @can_return_tuple
@@ -735,49 +744,304 @@ class VJEPA2Predictor(nn.Module):
        )


class VJEPA2PoolerSelfAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


class VJEPA2PoolerCrossAttention(nn.Module):
    """It's different from other cross-attention layers: it doesn't have an output projection layer (o_proj)."""

    # in case of modular refactoring - o_proj can be replaced with nn.Identity()

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, q_seq_length, embed_dim = queries.shape
        kv_seq_length = keys.shape[1]

        queries = self.q_proj(queries)
        keys = self.k_proj(keys)
        values = self.v_proj(values)

        queries = queries.view(batch_size, q_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous()

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


# Modified from SiglipEncoderLayer, but we have to propagate proper hidden_size to VJEPA2MLP
class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self_attn = VJEPA2PoolerSelfAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = VJEPA2PoolerCrossAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        queries: torch.Tensor,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        # Apply cross-attention
        residual = queries
        hidden_state = self.layer_norm1(hidden_state)
        hidden_state, *attn_weights = self.cross_attn(
            queries,
            hidden_state,
            hidden_state,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_state = residual + hidden_state

        # Apply MLP
        residual = hidden_state
        hidden_state = self.layer_norm2(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        outputs = (hidden_state,)
        if output_attentions:
            outputs += tuple(attn_weights)

        return outputs


class VJEPA2AttentivePooler(nn.Module):
    """Attentive Pooler"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
        self.self_attention_layers = nn.ModuleList(
            [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        for layer in self.self_attention_layers:
            hidden_state = layer(hidden_state, attention_mask=None)[0]
        queries = self.query_tokens.repeat(hidden_state.shape[0], 1, 1)
        hidden_state = self.cross_attention_layer(queries, hidden_state)[0]
        return hidden_state.squeeze(1)


@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
    config_class = VJEPA2Config
    base_model_prefix = "vjepa2"
    main_input_name = "pixel_values_videos"
    supports_gradient_checkpointing = True
    _no_split_modules = ["VJEPA2Layer"]
    _no_split_modules = [
        "VJEPA2Layer",
        "VJEPA2PoolerSelfAttentionLayer",
        "VJEPA2PoolerCrossAttentionLayer",
        "VJEPA2PredictorEmbeddings",
    ]
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(
        self,
        module: Union[
            nn.Linear,
            nn.Conv2d,
            nn.LayerNorm,
            VJEPA2Embeddings,
            VJEPA2PredictorEmbeddings,
        ],
    ):
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.weight.dtype)

        init_std = self.config.initializer_range

        # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
        # `trunc_normal_cpu` not implemented in `half` issues
        def trunc_normal_f32_(weight, std):
            data_float_32 = weight.data.to(torch.float32)
            data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std)
            weight.data = data_init.to(weight.dtype)

        if isinstance(module, VJEPA2AttentivePooler):
            trunc_normal_f32_(module.query_tokens, std=init_std)
            for i, layer in enumerate(module.self_attention_layers, 1):
                std = init_std / (i**0.5)
                trunc_normal_f32_(layer.self_attn.out_proj.weight, std=std)
                trunc_normal_f32_(layer.mlp.fc2.weight, std=std)
            std = init_std / (len(module.self_attention_layers) + 1) ** 0.5
            trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=std)
        elif isinstance(module, VJEPA2PredictorEmbeddings):
            if module.zero_init_mask_tokens:
                module.mask_tokens.data.zero_()
            else:
                trunc_normal_f32_(module.mask_tokens, std=init_std)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_f32_(module.weight, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, VJEPA2PredictorEmbeddings):
            if not module.zero_init_mask_tokens:
                module.mask_token = nn.init.trunc_normal_(
                    module.mask_token.to(torch.float32),
                    mean=0.0,
                    std=self.config.initializer_range,
                ).to(module.mask_token.dtype)
            else:
                module.mask_tokens.data.zero_()


def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
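
For reference, the depth-scaled initialization above shrinks the std of each pooler layer's `out_proj`/`fc2` weights as the layer index grows. A quick sanity check of the resulting values, computed directly from the formulas in `_init_weights` with the defaults `initializer_range=0.02` and `num_pooler_layers=3`:

```python
# Reproduces the std schedule used in _init_weights for the attentive pooler.
init_std = 0.02          # config.initializer_range default
num_pooler_layers = 3    # config.num_pooler_layers default

for i in range(1, num_pooler_layers + 1):
    print(f"self-attention layer {i}: std = {init_std / i**0.5:.4f}")
print(f"cross-attention layer:   std = {init_std / (num_pooler_layers + 1) ** 0.5:.4f}")
# self-attention layer 1: std = 0.0200
# self-attention layer 2: std = 0.0141
# self-attention layer 3: std = 0.0115
# cross-attention layer:   std = 0.0100
```
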
@@ -900,4 +1164,92 @@ class VJEPA2Model(VJEPA2PreTrainedModel):
        return encoder_output.last_hidden_state


__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"]
@auto_docstring(
    custom_intro="""
    V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
    """
)
class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vjepa2 = VJEPA2Model(config)

        # Classifier head
        self.pooler = VJEPA2AttentivePooler(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutput]:
        r"""
        pixel_values_videos (`torch.Tensor` with shape `[batch_size x num_frames x num_channels x height x width]`):
            The input video pixels, as processed by `VJEPA2VideoProcessor`.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import torch
        >>> import numpy as np
        >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification

        >>> device = "cuda"

        >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
        >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)

        >>> video = np.ones((64, 256, 256, 3))  # 64 frames, 256x256 RGB
        >>> inputs = video_processor(video, return_tensors="pt").to(device)

        >>> # For inference
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> logits = outputs.logits

        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])

        >>> # For training
        >>> labels = torch.ones(1, dtype=torch.long, device=device)
        >>> loss = model(**inputs, labels=labels).loss

        ```"""

        outputs = self.vjepa2(
            pixel_values_videos=pixel_values_videos,
            skip_predictor=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = outputs.last_hidden_state
        pooler_output = self.pooler(last_hidden_state)
        logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel", "VJEPA2ForVideoClassification"]
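
The classification head above reduces a `[batch, num_patches, hidden]` sequence to one vector per clip: the pooler's self-attention layers refine the patch features, then a single learned query token cross-attends over them before the linear classifier. A self-contained sketch of just that pooling idea using plain `torch.nn.MultiheadAttention` (not the actual VJEPA2 modules, and with made-up sizes):

```python
import torch
from torch import nn


class LearnedQueryPooling(nn.Module):
    """Minimal stand-in for the attentive-pooling idea: one learned query
    cross-attends over the patch sequence and its output becomes the clip embedding."""

    def __init__(self, hidden_size: int = 64, num_heads: int = 4):
        super().__init__()
        self.query = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, patch_features: torch.Tensor) -> torch.Tensor:
        # patch_features: [batch, num_patches, hidden]
        queries = self.query.expand(patch_features.shape[0], -1, -1)
        pooled, _ = self.cross_attn(queries, patch_features, patch_features)
        return pooled.squeeze(1)  # [batch, hidden]


features = torch.randn(2, 2048, 64)  # a flattened spatio-temporal patch sequence
print(LearnedQueryPooling()(features).shape)  # torch.Size([2, 64])
```
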
@@ -56,6 +56,7 @@ from ..models.auto.modeling_auto import (
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
@@ -194,6 +195,7 @@ _SPECIAL_SUPPORTED_MODELS = [
    "TrOCRDecoder",
    "PeftModelForCausalLM",
    "PeftModelForSeq2SeqLM",
    "VJEPA2ForVideoClassification",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]
@@ -904,6 +906,7 @@ class HFTracer(Tracer):
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
                *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
                *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
            ]:
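
These changes register the video-classification mapping with `HFTracer`, so `VJEPA2ForVideoClassification` can be symbolically traced with automatically generated dummy inputs. A hedged sketch of how this is typically exercised; the checkpoint name is reused from the docs above, and `symbolic_trace` is the public helper in `transformers.utils.fx`:

```python
from transformers import VJEPA2ForVideoClassification
from transformers.utils.fx import symbolic_trace

model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")

# HFTracer builds dummy pixel_values_videos for the registered model class.
traced = symbolic_trace(model, input_names=["pixel_values_videos"])
print(type(traced))  # a torch.fx.GraphModule
```
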
@@ -40,7 +40,7 @@ if is_torch_available():
    import torch
    from torch import nn

    from transformers import VJEPA2Model
    from transformers import VJEPA2ForVideoClassification, VJEPA2Model


if is_vision_available():
@@ -153,7 +153,7 @@ class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

    test_torch_exportable = True

    all_model_classes = (VJEPA2Model,) if is_torch_available() else ()
    all_model_classes = (VJEPA2Model, VJEPA2ForVideoClassification) if is_torch_available() else ()

    fx_compatible = True

@@ -267,7 +267,7 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase):
            [[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
            device=torch_device,
        )
        torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=8e-2, atol=8e-2)

    @slow
    def test_inference_video(self):
@@ -343,3 +343,22 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase):
        # verify the last hidden states
        expected_shape = torch.Size((1, num_masks, 1024))
        self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)

    @slow
    def test_video_classification(self):
        checkpoint = "facebook/vjepa2-vitl-fpc16-256-ssv2"

        model = VJEPA2ForVideoClassification.from_pretrained(checkpoint).to(torch_device)
        video_processor = AutoVideoProcessor.from_pretrained(checkpoint)

        sample_video = np.ones((16, 3, 256, 256))
        inputs = video_processor(sample_video, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        self.assertEqual(outputs.logits.shape, (1, 174))

        expected_logits = torch.tensor([0.8814, -0.1195, -0.6389], device=torch_device)
        resulted_logits = outputs.logits[0, 100:103]
        torch.testing.assert_close(resulted_logits, expected_logits, rtol=1e-2, atol=1e-2)