Add I-JEPA (#33125)

* first draft

* add IJepaEmbeddings class

* fix copy-from for IJepa model

* add weight conversion script

* update attention class names in IJepa model

* style changes

* Add push_to_hub option to convert_ijepa_checkpoint function

* add initial tests for I-JEPA

* minor style changes to conversion script

* make fixup related

* rename conversion script

* Add I-JEPA to sdpa docs

* minor fixes

* adjust conversion script

* update conversion script

* adjust sdpa docs

* [run_slow] ijepa

* [run-slow] ijepa

* [run-slow] ijepa

* [run-slow] ijepa

* [run-slow] ijepa

* [run-slow] ijepa

* formatting issues

* adjust modeling to modular code

* add IJepaModel to objects to ignore in docstring checks

* [run-slow] ijepa

* fix formatting issues

* add usage instruction snippet to docs

* change pos encoding, add checkpoint for doc

* add verify logits for all models

* [run-slow] ijepa

* update docs to include image feature extraction instructions

* remove pooling layer from IJepaModel in image classification class

* [run-slow] ijepa

* remove pooling layer from IJepaModel constructor

* update docs

* [run-slow] ijepa

* [run-slow] ijepa

* small changes

* [run-slow] ijepa

* style adjustments

* update copyright in init file

* adjust modular ijepa

* [run-slow] ijepa
João Marcelo 2024-12-05 16:14:46 +01:00 committed by GitHub
parent 95a855e212
commit 50189e36a6
19 changed files with 1907 additions and 2 deletions

View File

@@ -657,6 +657,8 @@
title: GLPN
- local: model_doc/hiera
title: Hiera
- local: model_doc/ijepa
title: I-JEPA
- local: model_doc/imagegpt
title: ImageGPT
- local: model_doc/levit

View File

@@ -168,6 +168,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ |
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
| [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ |
| [I-JEPA](model_doc/ijepa) | ✅ | ❌ | ❌ |
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
| [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
| [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ |

View File

@@ -0,0 +1,78 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# I-JEPA
## Overview
The I-JEPA model was proposed in [Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture](https://arxiv.org/pdf/2301.08243.pdf) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
I-JEPA is a self-supervised learning method that predicts the representations of one part of an image from other parts of the same image. This approach focuses on learning semantic features without relying on pre-defined invariances from hand-crafted data transformations, which tend to be biased toward particular tasks, or on filling in pixel-level details, which often leads to less semantically meaningful representations.
The abstract from the paper is the following:
*This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image-based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample target blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transformers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction.*
This model was contributed by [jmtzt](https://huggingface.co/jmtzt).
The original code can be found [here](https://github.com/facebookresearch/ijepa).
## How to use
Here is how to use this model for image feature extraction:
```python
import requests
import torch
from PIL import Image
from torch.nn.functional import cosine_similarity
from transformers import AutoModel, AutoProcessor
url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
image_1 = Image.open(requests.get(url_1, stream=True).raw)
image_2 = Image.open(requests.get(url_2, stream=True).raw)
model_id = "jmtzt/ijepa_vith14_1k"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
@torch.no_grad()
def infer(image):
    inputs = processor(image, return_tensors="pt")
    outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1)
embed_1 = infer(image_1)
embed_2 = infer(image_2)
similarity = cosine_similarity(embed_1, embed_2)
print(similarity)
```
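For image classification, `IJepaForImageClassification` adds a linear head on top of the mean-pooled patch tokens. Below is a minimal sketch continuing from the snippet above; note that `jmtzt/ijepa_vith14_1k` ships only the self-supervised encoder, so the head is randomly initialized and must be fine-tuned before its predictions are meaningful:
```python
import torch

from transformers import AutoModelForImageClassification

# num_labels=2 is an arbitrary illustrative choice; the head is freshly initialized.
clf = AutoModelForImageClassification.from_pretrained(model_id, num_labels=2)

with torch.no_grad():
    logits = clf(**processor(image_1, return_tensors="pt")).logits
print(logits.shape)  # torch.Size([1, 2])
```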
## IJepaConfig
[[autodoc]] IJepaConfig
## IJepaModel
[[autodoc]] IJepaModel
- forward
## IJepaForImageClassification
[[autodoc]] IJepaForImageClassification
- forward

View File

@@ -235,6 +235,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
@@ -242,7 +243,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
* [I-JEPA](https://huggingface.co/docs/transformers/model_doc/ijepa#transformers.IJepaModel)
* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)

View File

@@ -485,6 +485,7 @@ _import_structure = {
"models.idefics": ["IdeficsConfig"],
"models.idefics2": ["Idefics2Config"],
"models.idefics3": ["Idefics3Config"],
"models.ijepa": ["IJepaConfig"],
"models.imagegpt": ["ImageGPTConfig"],
"models.informer": ["InformerConfig"],
"models.instructblip": [
@@ -2462,6 +2463,13 @@ else:
"Idefics3Processor",
]
)
_import_structure["models.ijepa"].extend(
[
"IJepaForImageClassification",
"IJepaModel",
"IJepaPreTrainedModel",
]
)
_import_structure["models.imagegpt"].extend(
[
"ImageGPTForCausalImageModeling",
@@ -5368,6 +5376,7 @@ if TYPE_CHECKING:
)
from .models.idefics2 import Idefics2Config
from .models.idefics3 import Idefics3Config
from .models.ijepa import IJepaConfig
from .models.imagegpt import ImageGPTConfig
from .models.informer import InformerConfig
from .models.instructblip import (
@@ -7181,6 +7190,11 @@ if TYPE_CHECKING:
Idefics3PreTrainedModel,
Idefics3Processor,
)
from .models.ijepa import (
IJepaForImageClassification,
IJepaModel,
IJepaPreTrainedModel,
)
from .models.imagegpt import (
ImageGPTForCausalImageModeling,
ImageGPTForImageClassification,

View File

@@ -117,6 +117,7 @@ from . import (
idefics,
idefics2,
idefics3,
ijepa,
imagegpt,
informer,
instructblip,

View File

@@ -135,6 +135,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("idefics", "IdeficsConfig"),
("idefics2", "Idefics2Config"),
("idefics3", "Idefics3Config"),
("ijepa", "IJepaConfig"),
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
@@ -440,6 +441,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("idefics", "IDEFICS"),
("idefics2", "Idefics2"),
("idefics3", "Idefics3"),
("ijepa", "I-JEPA"),
("imagegpt", "ImageGPT"),
("informer", "Informer"),
("instructblip", "InstructBLIP"),

View File

@@ -90,6 +90,7 @@ else:
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("idefics3", ("Idefics3ImageProcessor",)),
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor",)),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
@@ -433,7 +434,9 @@ class AutoImageProcessor:
if image_processor_class is None and image_processor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
# It could be in `config.image_processor_type`
image_processor_class = getattr(config, "image_processor_type", None)

View File

@@ -132,6 +132,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("idefics", "IdeficsModel"),
("idefics2", "Idefics2Model"),
("idefics3", "Idefics3Model"),
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
@@ -578,6 +579,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("focalnet", "FocalNetModel"),
("glpn", "GLPNModel"),
("hiera", "HieraModel"),
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("mllama", "MllamaVisionModel"),
@@ -655,6 +657,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("efficientnet", "EfficientNetForImageClassification"),
("focalnet", "FocalNetForImageClassification"),
("hiera", "HieraForImageClassification"),
("ijepa", "IJepaForImageClassification"),
("imagegpt", "ImageGPTForImageClassification"),
(
"levit",

View File

@@ -0,0 +1,55 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {"configuration_ijepa": ["IJepaConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_ijepa"] = [
"IJepaForImageClassification",
"IJepaModel",
"IJepaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_ijepa import IJepaConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_ijepa import (
IJepaForImageClassification,
IJepaModel,
IJepaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
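# Illustration: with the lazy module in place, `from transformers.models.ijepa import
# IJepaModel` only imports `modeling_ijepa` (and therefore torch) when the attribute
# is first accessed, keeping `import transformers` cheap.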

View File

@@ -0,0 +1,108 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""I-JEPA model configuration"""
from ...configuration_utils import PretrainedConfig
class IJepaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`IJepaModel`]. It is used to instantiate an I-JEPA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a base-sized I-JEPA configuration. Note that the released checkpoints, such as
[jmtzt/ijepa_vith14_1k](https://huggingface.co/jmtzt/ijepa_vith14_1k), are huge and giant variants.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
Example:
```python
>>> from transformers import IJepaConfig, IJepaModel
>>> # Initializing a base-sized I-JEPA style configuration
>>> configuration = IJepaConfig()
>>> # Initializing a model (with random weights) from the base-sized configuration
>>> model = IJepaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "ijepa"
def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-12,
image_size=224,
patch_size=16,
num_channels=3,
qkv_bias=True,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias

View File

@@ -0,0 +1,267 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert IJEPA checkpoints from the original repository.
URL: https://github.com/facebookresearch/ijepa
"""
import argparse
import gc
import re
from pathlib import Path
import requests
import torch
from PIL import Image
from transformers import (
IJepaConfig,
IJepaModel,
ViTImageProcessor,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Projection layer + position embeddings
r"pos_embed": r"embeddings.position_embeddings",
r"patch_embed.proj.weight": r"embeddings.patch_embeddings.projection.weight",
r"patch_embed.proj.bias": r"embeddings.patch_embeddings.projection.bias",
# Encoder layers: Layernorms, Attention, Feedforward layers
r"blocks.(\d+).norm1.weight": r"encoder.layer.\1.layernorm_before.weight",
r"blocks.(\d+).norm1.bias": r"encoder.layer.\1.layernorm_before.bias",
r"blocks.(\d+).attn.proj.weight": r"encoder.layer.\1.attention.output.dense.weight",
r"blocks.(\d+).attn.proj.bias": r"encoder.layer.\1.attention.output.dense.bias",
r"blocks.(\d+).norm2.weight": r"encoder.layer.\1.layernorm_after.weight",
r"blocks.(\d+).norm2.bias": r"encoder.layer.\1.layernorm_after.bias",
r"blocks.(\d+).mlp.fc1.weight": r"encoder.layer.\1.intermediate.dense.weight",
r"blocks.(\d+).mlp.fc1.bias": r"encoder.layer.\1.intermediate.dense.bias",
r"blocks.(\d+).mlp.fc2.weight": r"encoder.layer.\1.output.dense.weight",
r"blocks.(\d+).mlp.fc2.bias": r"encoder.layer.\1.output.dense.bias",
# Layernorm + pooler
r"norm.weight": r"layernorm.weight",
r"norm.bias": r"layernorm.bias",
}
# fmt: on
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
"""
Converts old keys to new keys using the mapping and dynamically removes the 'ijepa.' prefix if necessary.
Args:
state_dict_keys (dict): The keys from the state_dict to convert.
Returns:
dict: A mapping from old keys to new keys.
"""
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
# Apply regex-based mapping
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # Skip the key
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
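# Illustration (derived from the regex mapping above): a fused original key such as
# "blocks.0.attn.proj.weight" is rewritten to its HF counterpart:
#
#     convert_old_keys_to_new_keys(["blocks.0.attn.proj.weight"])
#     # -> {"blocks.0.attn.proj.weight": "encoder.layer.0.attention.output.dense.weight"}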
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
for i in range(config.num_hidden_layers):
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :]
state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
config.hidden_size : config.hidden_size * 2
]
state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :]
state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
def get_ijepa_config(model_name):
patch_size = int(model_name.split("_")[1][4:])
config = IJepaConfig(patch_size=patch_size)
if "vith" in model_name:
config.hidden_size = 1280
config.num_hidden_layers = 32
config.num_attention_heads = 16
config.layer_norm_eps = 1e-6
config.mlp_ratio = 4
config.intermediate_size = 5120
if model_name == "ijepa_vith16_1k":
config.image_size = 448
elif "vitg" in model_name:
config.hidden_size = 1408
config.num_hidden_layers = 40
config.num_attention_heads = 16
config.layer_norm_eps = 1e-6
config.mlp_ratio = 48 / 11
config.intermediate_size = 6144
else:
raise ValueError("Model not supported, only supports huge and giant models.")
return config
@torch.no_grad()
def write_model(model_name, output_dir, safe_serialization, push_to_hub, verify_logits):
"""
Copy/paste/tweak model's weights to our IJEPA structure.
"""
# define default IJEPA configuration
config = get_ijepa_config(model_name)
checkpoint_mapping = {
"ijepa_vith14_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar",
"ijepa_vith14_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar",
"ijepa_vith16_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar",
"ijepa_vitg16_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar",
}
# Load original checkpoint
checkpoint_url = checkpoint_mapping[model_name]
original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["encoder"]
original_state_dict = {k.replace("module.", ""): v for k, v in original_state_dict.items()}
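# Note: the pretrained weights live under the checkpoint's "encoder" key, with
# DistributedDataParallel-style "module." prefixes that are stripped above.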
# Rename keys
state_dict = original_state_dict.copy()
new_keys = convert_old_keys_to_new_keys(state_dict.keys())
for old_key, new_key in new_keys.items():
rename_key(state_dict, old_key, new_key)
read_in_q_k_v(state_dict, config)
# load HuggingFace model
model = IJepaModel(config, add_pooling_layer=False).eval()
model.load_state_dict(state_dict)
size = {"height": config.image_size, "width": config.image_size}
image_processor = ViTImageProcessor(size=size)
if verify_logits:
# Check outputs on an image, prepared by ViTImageProcessor
encoding = image_processor(images=prepare_img(), return_tensors="pt")
pixel_values = encoding["pixel_values"]
with torch.no_grad():
outputs = model(pixel_values)
expected_slices = {
"ijepa_vith14_1k": torch.Tensor(
[[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
),
"ijepa_vith14_22k": torch.Tensor(
[[0.0358, -0.0045, -0.2154], [0.0418, -0.0246, 0.0108], [0.2529, -0.0345, -0.0246]]
),
"ijepa_vith16_1k": torch.Tensor(
[[0.5145, -0.1259, 0.0615], [0.1132, 0.0028, -0.0496], [1.1586, -0.0056, -0.0387]]
),
"ijepa_vitg16_22k": torch.Tensor(
[[0.0512, -0.0510, -0.0649], [0.1972, 0.0380, -0.0790], [0.1667, -0.0834, -0.1240]]
),
}
assert torch.allclose(
expected_slices[model_name],
outputs.last_hidden_state[0, :3, :3],
atol=1e-4,
)
if output_dir:
Path(output_dir).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {output_dir}")
image_processor.save_pretrained(output_dir, safe_serialization=safe_serialization)
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
if push_to_hub:
image_processor.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization)
model.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization)
if output_dir:
del model, state_dict
gc.collect()
print("Reloading the model to check if it's saved correctly.")
IJepaModel.from_pretrained(output_dir, device_map="auto")
print("Model reloaded successfully.")
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="ijepa_vith14_1k",
type=str,
choices=[
"ijepa_vith14_1k",
"ijepa_vith14_22k",
"ijepa_vith16_1k",
"ijepa_vitg16_22k",
],
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--safe_serialization",
default=True,
type=lambda x: str(x).lower() == "true",  # plain `type=bool` would treat any non-empty string, including "False", as True
help="Whether or not to save using `safetensors`.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the model to the 🤗 Hub.",
)
parser.add_argument(
"--verify_logits", action="store_false", help="Logit verification runs by default; pass this flag to skip it."
)
parser.set_defaults()
args = parser.parse_args()
write_model(args.model_name, args.output_dir, args.safe_serialization, args.push_to_hub, args.verify_logits)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,751 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/ijepa/modular_ijepa.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_ijepa.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
torch_int,
)
from .configuration_ijepa import IJepaConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k"
# General docstring
_CONFIG_FOR_DOC = "IJepaConfig"
class IJepaPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
f" Expected {self.num_channels} but got {num_channels}."
)
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
class IJepaEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
super().__init__()
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = IJepaPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1]
num_positions = self.position_embeddings.shape[1]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
patch_pos_embed = self.position_embeddings
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
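# Resolution illustration: with patch_size=14, a 224x224 input yields (224 / 14) ** 2 = 256
# patch tokens, matching the (1, 256, hidden_size) position table; a 448x448 input with
# interpolate_pos_encoding=True resizes the 16x16 position grid to 32x32 (1024 tokens).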
class IJepaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = IJepaConfig
base_model_prefix = "ijepa"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, IJepaEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)
class IJepaSelfAttention(nn.Module):
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class IJepaSdpaSelfAttention(IJepaSelfAttention):
def __init__(self, config: IJepaConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`IJepaSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
class IJepaSelfOutput(nn.Module):
"""
The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class IJepaAttention(nn.Module):
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
self.attention = IJepaSelfAttention(config)
self.output = IJepaSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class IJepaSdpaAttention(IJepaAttention):
def __init__(self, config: IJepaConfig) -> None:
super().__init__(config)
self.attention = IJepaSdpaSelfAttention(config)
class IJepaIntermediate(nn.Module):
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class IJepaOutput(nn.Module):
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
IJEPA_ATTENTION_CLASSES = {
"eager": IJepaAttention,
"sdpa": IJepaSdpaAttention,
}
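# Selection illustration: `IJepaLayer` indexes this mapping with `config._attn_implementation`,
# so e.g. `IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k", attn_implementation="eager")`
# forces the manual attention path instead of SDPA.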
class IJepaLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = IJepaIntermediate(config)
self.output = IJepaOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in IJepa, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# first residual connection
hidden_states = attention_output + hidden_states
# in IJepa, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
# second residual connection is done here
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
class IJepaEncoder(nn.Module):
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([IJepaLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class IJepaPooler(nn.Module):
def __init__(self, config: IJepaConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
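# Note: `IJepaModel` below defaults to add_pooling_layer=False, since I-JEPA has no CLS
# token and token 0 is an ordinary patch embedding; this pooler is only used when a
# caller opts in with add_pooling_layer=True.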
IJEPA_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`IJepaImageProcessor.__call__`]
for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]
IJEPA_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`IJepaConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.",
IJEPA_START_DOCSTRING,
)
class IJepaModel(IJepaPreTrainedModel):
def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
super().__init__(config)
self.config = config
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = IJepaEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = IJepaPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> IJepaPatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_vith14_1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
@add_start_docstrings(
"""
IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states)
e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""",
IJEPA_START_DOCSTRING,
)
class IJepaForImageClassification(IJepaPreTrainedModel):
def __init__(self, config: IJepaConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
self.ijepa = IJepaModel(config, add_pooling_layer=False)
# Classifier head
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.ijepa(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
sequence_output = outputs[0]
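# I-JEPA has no CLS token, so classification pools by averaging over all patch tokens
# rather than reading a class token.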
logits = self.classifier(sequence_output.mean(dim=1))
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = ["IJepaPreTrainedModel", "IJepaModel", "IJepaForImageClassification"]

View File

@@ -0,0 +1,255 @@
from typing import Optional, Union
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.ijepa.configuration_ijepa import IJepaConfig
from ...modeling_outputs import ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
torch_int,
)
from ..vit.modeling_vit import (
ViTEmbeddings,
ViTForImageClassification,
ViTModel,
)
_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k"
class IJepaEmbeddings(ViTEmbeddings):
def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
super().__init__(config, use_mask_token)
# Remove cls_token from IJepaEmbeddings, as it is not used in the model
del self.cls_token
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
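# Without a CLS token, the position table covers exactly num_patches entries
# (e.g. (1, 256, 1280) for vith14 at 224px resolution), versus num_patches + 1 in ViT.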
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1]
num_positions = self.position_embeddings.shape[1]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
patch_pos_embed = self.position_embeddings
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class IJepaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = IJepaConfig
base_model_prefix = "ijepa"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, IJepaEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)
_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]
IJEPA_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`IJepaConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.",
IJEPA_START_DOCSTRING,
)
class IJepaModel(IJepaPreTrainedModel, ViTModel):
def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
super().__init__(config)
self.config = config
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_vith14_1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
@add_start_docstrings(
"""
    IJepa Model transformer with an image classification head on top (a linear layer on top of the averaged final
    hidden states), e.g. for ImageNet.
<Tip>
    Note that it's possible to fine-tune IJepa on higher-resolution images than the ones it was trained on by
    setting `interpolate_pos_encoding` to `True` in the model's forward pass. This will interpolate the pre-trained
    position embeddings to the higher resolution.
</Tip>
""",
IJEPA_START_DOCSTRING,
)
class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification):
def __init__(self, config: IJepaConfig):
super().__init__(config)
self.ijepa = IJepaModel(config, add_pooling_layer=False)
self.post_init()
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Squared loss).
            If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.ijepa(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output.mean(dim=1))
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"IJepaPreTrainedModel",
"IJepaModel",
"IJepaForImageClassification",
]
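
The masking path in `IJepaEmbeddings.forward` above swaps masked patch embeddings for the shared learned mask token before the position embeddings are added. A minimal sketch of driving it, assuming a small randomly initialized config (the sizes and mask pattern are illustrative, not from this PR):

```python
import torch
from transformers import IJepaConfig, IJepaModel

# small, randomly initialized model (illustrative sizes)
config = IJepaConfig(
    image_size=32, patch_size=16, hidden_size=64,
    num_hidden_layers=2, num_attention_heads=2, intermediate_size=128,
)
model = IJepaModel(config, use_mask_token=True)  # enables the learned mask token
model.eval()

pixel_values = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)
num_patches = (32 // 16) ** 2             # 4 patches; I-JEPA adds no [CLS] token
bool_masked_pos = torch.zeros(1, num_patches, dtype=torch.bool)
bool_masked_pos[:, :2] = True             # mask the first two patches

with torch.no_grad():
    outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 4, 64])
```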

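Likewise, a hedged sketch of feature extraction at a higher resolution than pre-training, exercising `interpolate_pos_encoding` from the Tip above. The checkpoint and processor names come from this PR's tests; the 448×448 size and image path are illustrative assumptions:

```python
import torch
from PIL import Image
from transformers import ViTImageProcessor, IJepaModel

processor = ViTImageProcessor.from_pretrained("jmtzt/ijepa_vith14_1k")
model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k")
model.eval()

image = Image.open("cats.png")  # illustrative path to any RGB image
inputs = processor(images=image, size={"height": 448, "width": 448}, return_tensors="pt")

with torch.no_grad():
    # bicubic-resizes the pre-trained position embeddings to the larger patch grid
    outputs = model(**inputs, interpolate_pos_encoding=True)

features = outputs.last_hidden_state.mean(dim=1)  # one embedding per image
```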
View File

@@ -4978,6 +4978,27 @@ class Idefics3Processor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class IJepaForImageClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class IJepaModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class IJepaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ImageGPTForCausalImageModeling(metaclass=DummyObject):
_backends = ["torch"]
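
These additions follow the library's dummy-object pattern: without `torch` installed, the class can still be imported, but any use raises an informative error. A simplified sketch of the mechanism (the real `DummyObject` and `requires_backends` live in `transformers.utils` and emit richer install hints):

```python
def requires_backends(obj, backends):
    # simplified: the real helper formats per-backend installation instructions
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    """Metaclass so that attribute access on the class itself, not only
    instantiation, triggers the backend check."""
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)

class IJepaModel(metaclass=DummyObject):
    _backends = ["torch"]
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

IJepaModel()  # ImportError: IJepaModel requires the following backends: torch
```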

View File

@@ -140,6 +140,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"gptj",
"hiera",
"hubert",
"ijepa",
"layoutlm",
"llama",
"cohere",

View File

View File

@@ -0,0 +1,341 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch IJEPA model."""
import unittest
from transformers import IJepaConfig
from transformers.testing_utils import (
require_accelerate,
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
cached_property,
is_torch_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from torch import nn
from transformers import IJepaForImageClassification, IJepaModel
if is_vision_available():
from PIL import Image
from transformers import ViTImageProcessor
class IJepaModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
scope=None,
encoder_stride=2,
mask_ratio=0.5,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
self.attn_implementation = attn_implementation
# in IJEPA, the seq length equals the number of patches (we don't add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.num_channels,
self.image_size,
self.image_size,
]
)
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return IJepaConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
model = IJepaModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.seq_length, self.hidden_size),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = IJepaForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(
result.logits.shape,
(self.batch_size, self.type_sequence_label_size),
)
# test greyscale images
config.num_channels = 1
model = IJepaForImageClassification(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape,
(self.batch_size, self.type_sequence_label_size),
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
pixel_values,
labels,
) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
    Here we also override some of the tests in test_modeling_common.py, as IJEPA does not use input_ids,
    inputs_embeds, attention_mask, or seq_length.
"""
all_model_classes = (
(
IJepaModel,
IJepaForImageClassification,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"image-feature-extraction": IJepaModel, "image-classification": IJepaForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = True
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = IJepaModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=IJepaConfig,
has_text_modality=False,
hidden_size=37,
)
@unittest.skip(
"Since `torch==2.3+cu121`, although this test passes, many subsequent tests have `CUDA error: misaligned address`."
"If `nvidia-xxx-cu118` are also installed, no failure (even with `torch==2.3+cu121`)."
)
def test_multi_gpu_data_parallel_forward(self):
super().test_multi_gpu_data_parallel_forward()
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="IJEPA does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model_name = "jmtzt/ijepa_vith14_1k"
model = IJepaModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
class IJepaModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("jmtzt/ijepa_vith14_1k") if is_vision_available() else None
@slow
def test_inference_no_head(self):
model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
# verify the last hidden state
expected_shape = torch.Size((1, 256, 1280))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
        expected_slice = torch.tensor(
[[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
@slow
@require_accelerate
@require_torch_accelerator
@require_torch_fp16
def test_inference_fp16(self):
r"""
        A small test to make sure that inference works in half precision without any problems.
"""
model = IJepaModel.from_pretrained(
"jmtzt/ijepa_vith14_1k",
torch_dtype=torch.float16,
device_map="auto",
)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass to make sure inference works in fp16
with torch.no_grad():
_ = model(pixel_values)
@slow
def test_inference_interpolate_pos_encoding(self):
        # I-JEPA, like ViT models, has an `interpolate_pos_encoding` argument in its forward method
        # that interpolates the pre-trained position embeddings so the model can be used
        # on higher resolutions. The DINO model by Facebook AI leverages this
        # to visualize self-attention on higher-resolution images.
model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
        # verify the last hidden state
expected_shape = torch.Size((1, 256, 1280))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
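
Finally, a hedged sketch of the classification head in action, mirroring the `_IMAGE_CLASS_CHECKPOINT` / `_IMAGE_CLASS_EXPECTED_OUTPUT` pair from the modeling file; the COCO image is the same one `prepare_img` loads above:

```python
import torch
from PIL import Image
from transformers import ViTImageProcessor, IJepaForImageClassification

processor = ViTImageProcessor.from_pretrained("jmtzt/ijepa_vith14_1k")
model = IJepaForImageClassification.from_pretrained("jmtzt/ijepa_vith14_1k")
model.eval()

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted = logits.argmax(-1).item()
print(model.config.id2label[predicted])  # "Egyptian cat", per the doc constants above
```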

View File

@@ -331,6 +331,7 @@ OBJECTS_TO_IGNORE = [
"IBertModel",
"IdeficsConfig",
"IdeficsProcessor",
"IJepaModel",
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageGPTConfig",