Yaswanth Gali 2025-07-03 02:51:14 +09:00 committed by GitHub
commit 73df42d8f2
18 changed files with 3022 additions and 2 deletions


@@ -693,6 +693,8 @@
        title: Zamba2
    title: Text models
  - sections:
      - local: model_doc/aimv2
        title: Aimv2
      - local: model_doc/beit
        title: BEiT
      - local: model_doc/bit


@@ -0,0 +1,104 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# AIMv2
## Overview
The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby.
The abstract from the paper is the following:
*We introduce a novel method for pre-training of large-scale vision encoders. Building on recent advancements in autoregressive pre-training of vision models, we extend this framework to a multimodal setting, i.e., images and text. In this paper, we present AIMV2, a family of generalist vision encoders characterized by a straightforward pre-training process, scalability, and remarkable performance across a range of downstream tasks. This is achieved by pairing the vision encoder with a multimodal decoder that autoregressively generates raw image patches and text tokens. Our encoders excel not only in multimodal evaluations but also in vision benchmarks such as localization, grounding, and classification. Notably, our AIMV2-3B encoder achieves 89.5% accuracy on ImageNet-1k with a frozen trunk. Furthermore, AIMV2 consistently outperforms state-of-the-art contrastive models (e.g., CLIP, SigLIP) in multimodal image understanding across diverse settings.*
This model was contributed by [Yaswanth Gali](https://huggingface.co/yaswanthgali).
The original code can be found [here](https://github.com/apple/ml-aim).
## Usage Example
Here is an example of image feature extraction using a checkpoint trained on native-resolution images:
```python
import requests
from PIL import Image

from transformers import AutoImageProcessor, AutoModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-native")

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)

# Patch-level image features
features = outputs.last_hidden_state
```
Here is an example of zero-shot image classification:
```python
import requests
from PIL import Image

from transformers import AutoProcessor, AutoModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]

processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")

inputs = processor(
    images=image,
    text=text,
    add_special_tokens=True,
    truncation=True,
    padding=True,
    return_tensors="pt",
)
outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=-1)
```
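A vision-only checkpoint can also be loaded directly with [`Aimv2VisionModel`] to extract patch-level features. This is a minimal sketch along the same lines as the examples above, assuming the converted `apple/aimv2-large-patch14-224` checkpoint and its image processor are available on the Hub:

```python
import requests
from PIL import Image

from transformers import AutoImageProcessor, Aimv2VisionModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224")
model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-224")

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)

# Patch-level features of shape (batch_size, num_patches, hidden_size)
last_hidden_state = outputs.last_hidden_state
```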
## Aimv2Config
[[autodoc]] Aimv2Config
## Aimv2TextConfig
[[autodoc]] Aimv2TextConfig
## Aimv2VisionConfig
[[autodoc]] Aimv2VisionConfig
## Aimv2Model
[[autodoc]] Aimv2Model
- forward
## Aimv2VisionModel
[[autodoc]] Aimv2VisionModel
- forward
## Aimv2TextModel
[[autodoc]] Aimv2TextModel
- forward


@@ -18,6 +18,7 @@ from ..utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .aimv2 import *
from .albert import *
from .align import *
from .altclip import *


@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_aimv2 import *
from .modeling_aimv2 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@@ -0,0 +1,296 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aimv2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class Aimv2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Aimv2VisionModel`]. It is used to instantiate a
AIMv2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the AIMv2
[apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 2816):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries, keys and values.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the linear layers.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing all weight matrices.
use_head (`bool`, *optional*, defaults to `True`):
Whether to use the attention pooling head.
is_native (`bool`, *optional*, defaults to `False`):
Whether the checkpoint was trained on native-resolution images.
Example:
```python
>>> from transformers import Aimv2VisionConfig, Aimv2VisionModel
>>> # Initializing a Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration
>>> configuration = Aimv2VisionConfig()
>>> # Initializing a Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration
>>> model = Aimv2VisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "aimv2_vision_model"
base_config_key = "vision_config"
def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 2816,
num_hidden_layers: int = 24,
num_attention_heads: int = 8,
num_channels: int = 3,
image_size: int = 224,
patch_size: int = 14,
rms_norm_eps: float = 1e-5,
attention_dropout: float = 0.0,
qkv_bias: bool = False,
mlp_bias: bool = False,
hidden_act: str = "silu",
initializer_range: float = 0.02,
use_head: bool = True,
is_native: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.use_head = use_head
self.initializer_range = initializer_range
self.mlp_bias = mlp_bias
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.is_native = is_native
class Aimv2TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Aimv2TextModel`]. It is used to instantiate a
AIMv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the text encoder of the AIMv2
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 49408):
Vocabulary size of the AIMv2 text model. Defines the number of different tokens that can be represented by
the `inputs_ids` passed when calling [`Aimv2Model`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 2048):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 6):
Number of attention heads for each attention layer in the Transformer encoder.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries, keys and values.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the linear layers.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
pad_token_id (`int`, *optional*):
The id of the padding token in the vocabulary.
bos_token_id (`int`, *optional*):
The id of the beginning-of-sequence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 49407):
The id of the end-of-sequence token in the vocabulary.
max_position_embeddings (`int`, *optional*, defaults to 77):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing all weight matrices.
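Example (following the same pattern as the [`Aimv2VisionConfig`] example above):
```python
>>> from transformers import Aimv2TextConfig, Aimv2TextModel
>>> # Initializing a Aimv2TextConfig with apple/aimv2-large-patch14-224-lit style configuration
>>> configuration = Aimv2TextConfig()
>>> # Initializing a Aimv2TextModel (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration
>>> model = Aimv2TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```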
"""
model_type = "aimv2_text_model"
base_config_key = "text_config"
def __init__(
self,
vocab_size: int = 49408,
hidden_size: int = 768,
intermediate_size: int = 2048,
num_hidden_layers: int = 12,
num_attention_heads: int = 6,
rms_norm_eps: float = 1e-5,
attention_dropout: float = 0.0,
qkv_bias: bool = False,
mlp_bias: bool = False,
hidden_act: str = "silu",
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: int = 49407,
max_position_embeddings: int = 77,
initializer_range: float = 0.02,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.hidden_act = hidden_act
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.mlp_bias = mlp_bias
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
class Aimv2Config(PretrainedConfig):
r"""
[`Aimv2Config`] is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to
instantiate a AIMv2 model according to the specified arguments, defining the text model and vision model configs.
Instantiating a configuration with the defaults will yield a similar configuration to that of the AIMv2
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Aimv2TextConfig`].
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Aimv2VisionConfig`].
projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of text and vision projection layers.
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
The initial value of the *logit_scale* parameter.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```python
>>> from transformers import Aimv2Config, Aimv2Model
>>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration
>>> configuration = Aimv2Config()
>>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration
>>> model = Aimv2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig
>>> from transformers import Aimv2TextConfig, Aimv2VisionConfig
>>> # Initializing a AIMv2Text and AIMv2Vision configuration
>>> config_text = Aimv2TextConfig()
>>> config_vision = Aimv2VisionConfig()
>>> config = Aimv2Config(text_config=config_text, vision_config=config_vision)
```"""
model_type = "aimv2"
sub_configs = {"text_config": Aimv2TextConfig, "vision_config": Aimv2VisionConfig}
def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
):
super().__init__(**kwargs)
if text_config is None:
text_config = {}
logger.info("`text_config` is `None`. Initializing the `Aimv2TextConfig` with default values.")
if vision_config is None:
vision_config = {}
logger.info("`vision_config` is `None`. initializing the `Aimv2VisionConfig` with default values.")
self.text_config = Aimv2TextConfig(**text_config)
self.vision_config = Aimv2VisionConfig(**vision_config)
self.projection_dim = projection_dim
self.logit_scale_init_value = logit_scale_init_value
self.max_logit_scale = 100.0
@classmethod
def from_text_vision_configs(cls, text_config: Aimv2TextConfig, vision_config: Aimv2VisionConfig, **kwargs):
r"""
Instantiate a [`Aimv2Config`] (or a derived class) from aimv2 text model configuration and aimv2 vision
model configuration.
Returns:
[`Aimv2Config`]: An instance of a configuration object
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
__all__ = ["Aimv2Config", "Aimv2VisionConfig", "Aimv2TextConfig"]


@@ -0,0 +1,269 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
from typing import Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import (
Aimv2Config,
Aimv2Model,
Aimv2VisionConfig,
Aimv2VisionModel,
AutoImageProcessor,
AutoProcessor,
)
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL = {
# Embeddings
r"preprocessor.patchifier.proj": r"embeddings.patch_embed",
r"preprocessor.pos_embed": r"embeddings.position_embedding.weight",
r"preprocessor.patchifier.norm.weight": r"embeddings.rms_norm.weight",
# Encoder Layers
r"trunk.blocks.(\d+).attn.qkv": r"encoder.layers.\1.attention.qkv",
r"trunk.blocks.(\d+).attn.proj": r"encoder.layers.\1.attention.out_proj",
r"trunk.blocks.(\d+).mlp.fc1": r"encoder.layers.\1.ffn.gate_proj",
r"trunk.blocks.(\d+).mlp.fc2": r"encoder.layers.\1.ffn.down_proj",
r"trunk.blocks.(\d+).mlp.fc3": r"encoder.layers.\1.ffn.up_proj",
# Normalization Layers
r"trunk.blocks.(\d+).norm_1": r"encoder.layers.\1.rms_norm1",
r"trunk.blocks.(\d+).norm_2": r"encoder.layers.\1.rms_norm2",
# Final Norm
r"trunk.post_trunk_norm": r"rms_norm",
}
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Vision Embeddings
r"image_encoder.preprocessor.patchifier.proj": r"vision_model.embeddings.patch_embed",
r"image_encoder.preprocessor.pos_embed": r"vision_model.embeddings.position_embedding.weight",
r"image_encoder.preprocessor.patchifier.norm.weight": r"vision_model.embeddings.rms_norm.weight",
# Vision Encoder Layers
r"image_encoder.trunk.blocks.(\d+).attn.qkv": r"vision_model.encoder.layers.\1.attention.qkv",
r"image_encoder.trunk.blocks.(\d+).attn.proj": r"vision_model.encoder.layers.\1.attention.out_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc1": r"vision_model.encoder.layers.\1.ffn.gate_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc2": r"vision_model.encoder.layers.\1.ffn.down_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc3": r"vision_model.encoder.layers.\1.ffn.up_proj",
# Normalization Layers
r"image_encoder.trunk.blocks.(\d+).norm_1": r"vision_model.encoder.layers.\1.rms_norm1",
r"image_encoder.trunk.blocks.(\d+).norm_2": r"vision_model.encoder.layers.\1.rms_norm2",
r"image_encoder.trunk.post_trunk_norm": r"vision_model.rms_norm",
r"image_projector": r"visual_projection",
# Vision Head
r"image_encoder.head.cls_token": r"vision_model.head.cls_token",
r"image_encoder.head.k": r"vision_model.head.k_proj",
r"image_encoder.head.v": r"vision_model.head.v_proj",
r"image_encoder.head.linear": r"vision_model.head.output_proj",
# Text Embeddings
r"text_encoder.preprocessor.text_embedding.weight": r"text_model.embeddings.token_embedding.weight",
r"text_encoder.preprocessor.positional_embedding": r"text_model.embeddings.position_embedding.weight",
# Text Encoder Layers
r"text_encoder.trunk.blocks.(\d+).attn.qkv": r"text_model.encoder.layers.\1.attention.qkv",
r"text_encoder.trunk.blocks.(\d+).attn.proj": r"text_model.encoder.layers.\1.attention.out_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc1": r"text_model.encoder.layers.\1.ffn.gate_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc2": r"text_model.encoder.layers.\1.ffn.down_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc3": r"text_model.encoder.layers.\1.ffn.up_proj",
# Text Normalization Layers
r"text_encoder.trunk.blocks.(\d+).norm_1": r"text_model.encoder.layers.\1.rms_norm1",
r"text_encoder.trunk.blocks.(\d+).norm_2": r"text_model.encoder.layers.\1.rms_norm2",
r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
r"text_projector": r"text_projection",
r"log_logit_scale": r"logit_scale",
}
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
# Download only the model.safetensors file
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["model.safetensors"],
)
original_state_dict = {}
safetensor_path = f"{directory_path}/model.safetensors"
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
return original_state_dict
def convert_old_keys_to_new_keys(state_dict_keys: dict, ORIGINAL_TO_CONVERTED_KEY_MAPPING: dict):
"""Converts state dict keys from the old format to the new format."""
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
def split_qkv_tensor(key, tensor):
"""Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""
new_keys = ["q_proj", "k_proj", "v_proj"]
split_size = tensor.shape[0] // 3
split_tensors = torch.split(tensor, split_size, dim=0)
return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}
def get_model_config_mapping(model_id: str):
"""Determines the correct model, config, and key mappings based on the checkpoint name."""
if model_id == "apple/aimv2-large-patch14-224-lit":
return Aimv2Model, Aimv2Config, ORIGINAL_TO_CONVERTED_KEY_MAPPING
else:
return Aimv2VisionModel, Aimv2VisionConfig, ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL
def write_model(
hf_repo_id: str,
output_dir: str,
safe_serialization: bool = True,
):
"""
Converts a model checkpoint to Hugging Face format and saves it.
Args:
hf_repo_id (str): The Hugging Face repo ID to load from.
output_dir (str): The directory to save the converted model.
safe_serialization (bool): Whether to use safe serialization.
Returns:
model: The reloaded Hugging Face model.
"""
os.makedirs(output_dir, exist_ok=True)
# Get the appropriate model, config, and key mapping
model_class, config_class, key_mapping = get_model_config_mapping(hf_repo_id)
# Load config and original state dict
config = config_class.from_pretrained(hf_repo_id)
# Only the `apple/aimv2-large-patch14-224-lit` checkpoint uses the attention pooling head, so disable it for the vision-only checkpoints.
if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
config.use_head = False
if hf_repo_id == "apple/aimv2-large-patch14-native":
config.is_native = True
original_state_dict = load_original_state_dict(hf_repo_id)
print("Converting model...")
state_dict = {}
result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
all_keys = list(original_state_dict.keys())
for key in all_keys:
value = original_state_dict[key]
new_key = result.pop(key)
if "qkv" in new_key:
qkv_state_dict = split_qkv_tensor(new_key, value)
state_dict.update(qkv_state_dict)
else:
state_dict[new_key] = value
# Position embeddings are stored with a leading singleton dimension in the original checkpoint; squeeze it to match the `nn.Embedding` weight shape.
if new_key.endswith("position_embedding.weight"):
state_dict[new_key] = value.squeeze(0)
print(f"Loading the checkpoint in a {model_class.__name__}.")
model = model_class(config)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
model = model_class.from_pretrained(output_dir, device_map="auto")
print("Model reloaded successfully.")
return model
def write_image_processor(hf_repo_id: str, output_dir: str):
if hf_repo_id == "apple/aimv2-large-patch14-224-lit":
image_processor = AutoProcessor.from_pretrained(hf_repo_id, use_fast=True)
else:
image_processor = AutoImageProcessor.from_pretrained(hf_repo_id, use_fast=True)
image_processor.save_pretrained(output_dir)
return image_processor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--hf_repo_id",
default="apple/aimv2-large-patch14-224",
help="Location of official weights from apple on HF",
)
parser.add_argument(
"--output_dir",
default="aimv2_model",
help="Location to write the converted model and processor",
)
parser.add_argument(
"--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
)
parser.add_argument(
"--push_to_hub",
action=argparse.BooleanOptionalAction,
help="Whether or not to push the converted model to the huggingface hub.",
)
parser.add_argument(
"--hub_repo_id",
default=None,
help="Huggingface hub repo to write the converted model and processor",
)
args = parser.parse_args()
model = write_model(
hf_repo_id=args.hf_repo_id,
output_dir=args.output_dir,
safe_serialization=args.safe_serialization,
)
image_processor = write_image_processor(
hf_repo_id=args.hf_repo_id,
output_dir=args.output_dir,
)
if args.push_to_hub:
print("Pushing to hub...")
model.push_to_hub(args.hub_repo_id)
image_processor.push_to_hub(args.hub_repo_id)
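# Example invocation (the script filename here is illustrative; use the actual file name in the repository):
#   python convert_aimv2_original_pytorch_to_hf.py \
#       --hf_repo_id apple/aimv2-large-patch14-224-lit \
#       --output_dir ./aimv2-large-patch14-224-lit \
#       --push_to_hub --hub_repo_id <your-hub-repo>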
if __name__ == "__main__":
main()


@@ -0,0 +1,834 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aimv2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple
from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig
@dataclass
@auto_docstring
class Aimv2Output(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of [`Aimv2TextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of [`Aimv2VisionModel`].
text_model_output (`BaseModelOutputWithPooling`):
The output of the [`Aimv2TextModel`].
vision_model_output (`BaseModelOutputWithPooling`):
The output of the [`Aimv2VisionModel`].
"""
loss: Optional[torch.FloatTensor] = None
logits_per_image: Optional[torch.FloatTensor] = None
logits_per_text: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
@use_kernel_forward_from_hub("RMSNorm")
class Aimv2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Aimv2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Aimv2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class Aimv2VisionEmbeddings(nn.Module):
def __init__(self, config: Aimv2VisionConfig):
super().__init__()
self.config = config
self.patch_size = config.patch_size
self.patch_embed = nn.Conv2d(
config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
)
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
num_patches = (config.image_size // config.patch_size) ** 2
if not self.config.is_native:
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
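# For `is_native` checkpoints no learned position-embedding table is stored; instead, a 2D sin-cos embedding of
# shape (1, num_patches, hidden_size) matching the input resolution is built on the fly in `forward`.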
@staticmethod
def build_2d_sincos_position_embedding(
height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
) -> torch.Tensor:
grid_w = torch.arange(int(width), dtype=dtype, device=device)
grid_h = torch.arange(int(height), dtype=dtype, device=device)
grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
omega = 1.0 / (temperature**omega)
out_h = grid_h.flatten()[..., None] @ omega[None, :]
out_w = grid_w.flatten()[..., None] @ omega[None, :]
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
_, _, height, width = pixel_values.size()
hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
hidden_states = self.rms_norm(hidden_states)
if self.config.is_native:
pos_embed = self.build_2d_sincos_position_embedding(
height // self.patch_size,
width // self.patch_size,
embed_dim=self.config.hidden_size,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
else:
pos_embed = self.position_embedding(self.position_ids)
hidden_states = hidden_states + pos_embed
return hidden_states
class Aimv2TextEmbeddings(nn.Module):
def __init__(self, config: Aimv2TextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
max_position_embedding = self.position_embedding.weight.shape[0]
if seq_length > max_position_embedding:
raise ValueError(
f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
f"{seq_length} and max_position_embeddings: {max_position_embedding}"
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
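# Explicit scaled dot-product attention, used when `config._attn_implementation == "eager"`.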
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Aimv2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class Aimv2EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Aimv2VisionConfig):
super().__init__()
self.attention = Aimv2Attention(config)
self.ffn = Aimv2MLP(config)
self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
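# Pre-norm residual block: hidden_states + Attention(RMSNorm(hidden_states)), then hidden_states + MLP(RMSNorm(hidden_states)).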
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states = self.rms_norm1(hidden_states)
attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask)
hidden_states = hidden_states + attn_output
norm_hidden_states = self.rms_norm2(hidden_states)
mlp_output = self.ffn(norm_hidden_states)
hidden_states = hidden_states + mlp_output
return (hidden_states, attn_weights) if output_attentions else (hidden_states, None)
class Aimv2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Aimv2EncoderLayer`].
Args:
config: Aimv2Config
"""
def __init__(self, config: Aimv2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Ignore copy
@can_return_tuple
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutput:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
class Aimv2AttentionPoolingHead(nn.Module):
def __init__(self, config: Aimv2VisionConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
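# A learned cls_token query attends over all patch tokens (multi-head) and the result is projected to give a
# single pooled image embedding per sample.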
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, hidden_dim = hidden_states.shape
cls_token = self.cls_token.expand(batch_size, -1, -1)
key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
attn_output = F.scaled_dot_product_attention(query, key, value)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim)
attn_output = attn_output.mean(dim=1)
output = self.output_proj(attn_output)
return output
@auto_docstring
class Aimv2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models. The model is only intended for inference and doesn't support finetuning.
"""
config_class = Aimv2Config
base_model_prefix = "aimv2"
supports_gradient_checkpointing = True
_no_split_modules = [
"Aimv2EncoderLayer",
"Aimv2AttentionPoolingHead",
"Aimv2VisionEmbeddings",
"Aimv2TextEmbeddings",
]
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_flex_attn = True
def _init_weights(self, module):
std = (
self.config.vision_config.initializer_range
if hasattr(self.config, "vision_config")
else self.config.initializer_range
)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, Aimv2RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=std)
@auto_docstring(
custom_intro="""
The Vision model from AIMv2 without any head or projection on top.
"""
)
class Aimv2VisionModel(Aimv2PreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config: Aimv2VisionConfig):
super().__init__(config)
self.config = config
self.embeddings = Aimv2VisionEmbeddings(config)
self.encoder = Aimv2Encoder(config)
# The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.use_head = config.use_head
if self.use_head:
self.head = Aimv2AttentionPoolingHead(config)
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Aimv2VisionModel
>>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.rms_norm(last_hidden_state)
pooler_output = self.head(last_hidden_state) if self.use_head else None
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@auto_docstring(
custom_intro="""
The text model from AIMv2 without any head or projection on top.
"""
)
class Aimv2TextModel(Aimv2PreTrainedModel):
main_input_name = "input_ids"
def __init__(self, config: Aimv2TextConfig):
super().__init__(config)
self.config = config
self.embeddings = Aimv2TextEmbeddings(config)
self.encoder = Aimv2Encoder(config)
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.eos_token_id = config.eos_token_id
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.token_embedding
def set_input_embeddings(self, value):
self.embeddings.token_embedding = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
hidden_states = self.embeddings(input_ids)
_, seq_len, _ = hidden_states.shape
cache_position = torch.arange(seq_len, device=hidden_states.device)
if attention_mask is not None:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=None,
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.rms_norm(last_hidden_state)
# Pool by taking the hidden state at the position of the first EOS token in each sequence
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1),
]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
"""
This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
"""
square_tensor = torch.pow(tensor, 2)
sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
normed_tensor = torch.pow(sum_tensor, 0.5)
return normed_tensor
@auto_docstring
class Aimv2Model(Aimv2PreTrainedModel):
config_class = Aimv2Config
_no_split_modules = ["Aimv2TextEmbeddings", "Aimv2EncoderLayer", "Aimv2VisionEmbeddings"]
def __init__(self, config: Aimv2Config):
super().__init__(config)
self.projection_dim = config.projection_dim
self.vision_embed_dim = config.vision_config.hidden_size
self.text_embed_dim = config.text_config.hidden_size
self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
self.text_model = Aimv2TextModel._from_config(config.text_config)
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
self.max_log_logit_scale = math.log(config.max_logit_scale)
self.post_init()
@auto_docstring
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
applying the projection layer to the pooled output of [`Aimv2TextModel`].
Examples:
```python
>>> from transformers import AutoTokenizer, Aimv2Model
>>> model = Aimv2Model.from_pretrained("openai/aimv2-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/aimv2-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> text_features = model.get_text_features(**inputs)
```"""
# Use AIMV2 model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = text_outputs.pooler_output
text_features = self.text_projection(pooled_output)
return text_features
@auto_docstring
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
applying the projection layer to the pooled output of [`Aimv2VisionModel`].
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Aimv2Model
>>> model = Aimv2Model.from_pretrained("openai/aimv2-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/aimv2-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> image_features = model.get_image_features(**inputs)
```"""
# Use AIMV2 model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = vision_outputs.pooler_output
image_features = self.visual_projection(pooled_output)
return image_features
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Aimv2Output:
r"""
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Aimv2Model
>>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
... )
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
image_embeds = vision_outputs.pooler_output
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs.pooler_output
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / _get_vector_norm(image_embeds)
text_embeds = text_embeds / _get_vector_norm(text_embeds)
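# Cosine-similarity logits scaled by the learned temperature (log-scale clamped so the multiplier stays in [1, max_logit_scale])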
logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
logits_per_image = logits_per_text.t()
return Aimv2Output(
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
__all__ = ["Aimv2VisionModel", "Aimv2Model", "Aimv2PreTrainedModel", "Aimv2TextModel"]


@@ -0,0 +1,728 @@
# coding=utf-8
# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytorch implementation of AIMv2 Model"""
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import (
auto_docstring,
can_return_tuple,
)
from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoder, SiglipOutput
class Aimv2VisionConfig(SiglipVisionConfig):
r"""
This is the configuration class to store the configuration of a [`Aimv2VisionModel`]. It is used to instantiate a
AIMv2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the AIMv2
[apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 2816):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries, keys and values.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the Linear layers or not.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the normal initializer for initializing all weight matrices.
use_head (`bool`, *optional*, defaults to `True`):
Whether to use an attention pooling head on top of the encoder or not.
is_native (`bool`, *optional*, defaults to `False`):
Whether the checkpoint was trained for native image resolution or not.
Example:
```python
>>> from transformers import Aimv2VisionConfig, Aimv2VisionModel
>>> # Initializing a Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration
>>> configuration = Aimv2VisionConfig()
>>> # Initializing a Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration
>>> model = Aimv2VisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 2816,
num_hidden_layers: int = 24,
num_attention_heads: int = 8,
num_channels: int = 3,
image_size: int = 224,
patch_size: int = 14,
rms_norm_eps: float = 1e-5,
attention_dropout: float = 0.0,
qkv_bias: bool = False,
mlp_bias: bool = False,
hidden_act: str = "silu",
initializer_range: float = 0.02,
use_head: bool = True,
is_native: bool = False,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
hidden_act=hidden_act,
num_channels=num_channels,
image_size=image_size,
patch_size=patch_size,
qkv_bias=qkv_bias,
**kwargs,
)
self.use_head = use_head
self.initializer_range = initializer_range
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.is_native = is_native
del self.layer_norm_eps
class Aimv2TextConfig(SiglipTextConfig):
r"""
This is the configuration class to store the configuration of a [`Aimv2TextModel`]. It is used to instantiate a
AIMv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the text encoder of the AIMv2
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 49408):
Vocabulary size of the AIMv2 text model. Defines the number of different tokens that can be represented by
the `inputs_ids` passed when calling [`Aimv2Model`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 2048):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 6):
Number of attention heads for each attention layer in the Transformer encoder.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries, keys and values.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the Linear layers or not.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
pad_token_id (`int`, *optional*, defaults to 1):
The id of the padding token in the vocabulary.
bos_token_id (`int`, *optional*, defaults to 49406):
The id of the beginning-of-sequence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 49407):
The id of the end-of-sequence token in the vocabulary.
max_position_embeddings (`int`, *optional*, defaults to 77):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the normal initializer for initializing all weight matrices.
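Example (a minimal sketch, mirroring the vision config example above):
```python
>>> from transformers import Aimv2TextConfig, Aimv2TextModel
>>> # Initializing an Aimv2TextConfig with apple/aimv2-large-patch14-224-lit style configuration
>>> configuration = Aimv2TextConfig()
>>> # Initializing an Aimv2TextModel (with random weights) from the configuration above
>>> model = Aimv2TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```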
"""
def __init__(
self,
vocab_size: int = 49408,
hidden_size: int = 768,
intermediate_size: int = 2048,
num_hidden_layers: int = 12,
num_attention_heads: int = 6,
rms_norm_eps: float = 1e-5,
attention_dropout: float = 0.0,
qkv_bias: bool = False,
mlp_bias: bool = False,
hidden_act: str = "silu",
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: int = 49407,
max_position_embeddings: int = 77,
initializer_range: float = 0.02,
**kwargs,
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
hidden_act=hidden_act,
max_position_embeddings=max_position_embeddings,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
self.initializer_range = initializer_range
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
del self.bos_token_id
del self.pad_token_id
del self.projection_size
del self.layer_norm_eps
class Aimv2Config(SiglipConfig):
r"""
[`Aimv2Config`] is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to
instantiate a AIMv2 model according to the specified arguments, defining the text model and vision model configs.
Instantiating a configuration with the defaults will yield a similar configuration to that of the AIMv2
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Aimv2TextConfig`].
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Aimv2VisionConfig`].
projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of text and vision projection layers.
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
The initial value of the *logit_scale* parameter.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```python
>>> from transformers import Aimv2Config, Aimv2Model
>>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration
>>> configuration = Aimv2Config()
>>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration
>>> model = Aimv2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig
>>> from transformers import Aimv2TextConfig, Aimv2VisionConfig
>>> # Initializing a AIMv2Text and AIMv2Vision configuration
>>> config_text = Aimv2TextConfig()
>>> config_vision = Aimv2VisionConfig()
>>> config = Aimv2Config(text_config=config_text, vision_config=config_vision)
```"""
def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
):
super().__init__(text_config, vision_config, **kwargs)
self.projection_dim = projection_dim
self.logit_scale_init_value = logit_scale_init_value
self.max_logit_scale = 100.0
del self.initializer_factor
class Aimv2Output(SiglipOutput):
pass
class Aimv2RMSNorm(LlamaRMSNorm):
pass
class Aimv2MLP(LlamaMLP):
pass
class Aimv2VisionEmbeddings(nn.Module):
def __init__(self, config: Aimv2VisionConfig):
super().__init__()
self.config = config
self.patch_size = config.patch_size
self.patch_embed = nn.Conv2d(
config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
)
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
num_patches = (config.image_size // config.patch_size) ** 2
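# Fixed-resolution checkpoints learn the position embedding; native-resolution checkpoints build
# 2D sin-cos position embeddings on the fly in forward().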
if not self.config.is_native:
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
@staticmethod
def build_2d_sincos_position_embedding(
height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
) -> torch.Tensor:
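# Split embed_dim into four equal parts: sin and cos of the height coordinate, then sin and cos of the
# width coordinate, each scaled by inverse-temperature frequencies.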
grid_w = torch.arange(int(width), dtype=dtype, device=device)
grid_h = torch.arange(int(height), dtype=dtype, device=device)
grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
omega = 1.0 / (temperature**omega)
out_h = grid_h.flatten()[..., None] @ omega[None, :]
out_w = grid_w.flatten()[..., None] @ omega[None, :]
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
_, _, height, width = pixel_values.size()
hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
hidden_states = self.rms_norm(hidden_states)
if self.config.is_native:
pos_embed = self.build_2d_sincos_position_embedding(
height // self.patch_size,
width // self.patch_size,
embed_dim=self.config.hidden_size,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
else:
pos_embed = self.position_embedding(self.position_ids)
hidden_states = hidden_states + pos_embed
return hidden_states
class Aimv2TextEmbeddings(CLIPTextEmbeddings):
pass
class Aimv2Attention(SiglipAttention):
def __init__(self, config):
super().__init__(config)
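# Same as SiglipAttention, except that the bias of the q/k/v/out projections is controlled by config.qkv_bias.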
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
class Aimv2EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Aimv2VisionConfig):
super().__init__()
self.attention = Aimv2Attention(config)
self.ffn = Aimv2MLP(config)
self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states = self.rms_norm1(hidden_states)
attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask)
hidden_states = hidden_states + attn_output
norm_hidden_states = self.rms_norm2(hidden_states)
mlp_output = self.ffn(norm_hidden_states)
hidden_states = hidden_states + mlp_output
return (hidden_states, attn_weights) if output_attentions else (hidden_states, None)
class Aimv2Encoder(SiglipEncoder):
pass
class Aimv2AttentionPoolingHead(nn.Module):
def __init__(self, config: Aimv2VisionConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
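# A learned cls_token provides the single query; it cross-attends over all patch tokens and the attended
# result is projected into one pooled vector per image.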
batch_size, seq_len, hidden_dim = hidden_states.shape
cls_token = self.cls_token.expand(batch_size, -1, -1)
key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
attn_output = F.scaled_dot_product_attention(query, key, value)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim)
attn_output = attn_output.mean(dim=1)
output = self.output_proj(attn_output)
return output
@auto_docstring
class Aimv2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models. The model is only intended for inference and doesn't support finetuning.
"""
config_class = Aimv2Config
base_model_prefix = "aimv2"
supports_gradient_checkpointing = True
_no_split_modules = [
"Aimv2EncoderLayer",
"Aimv2AttentionPoolingHead",
"Aimv2VisionEmbeddings",
"Aimv2TextEmbeddings",
]
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_flex_attn = True
def _init_weights(self, module):
std = (
self.config.vision_config.initializer_range
if hasattr(self.config, "vision_config")
else self.config.initializer_range
)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, Aimv2RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
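# CLIP-style temperature initialization: start logit_scale at log(1 / 0.07).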
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=std)
@auto_docstring(
custom_intro="""
The Vision model from AIMv2 without any head or projection on top.
"""
)
class Aimv2VisionModel(Aimv2PreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config: Aimv2VisionConfig):
super().__init__(config)
self.config = config
self.embeddings = Aimv2VisionEmbeddings(config)
self.encoder = Aimv2Encoder(config)
# The only change from SiglipVisionTransformer is layernorm -> rms_norm.
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.use_head = config.use_head
if self.use_head:
self.head = Aimv2AttentionPoolingHead(config)
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Aimv2VisionModel
>>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.rms_norm(last_hidden_state)
pooler_output = self.head(last_hidden_state) if self.use_head else None
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@auto_docstring(
custom_intro="""
The text model from AIMv2 without any head or projection on top.
"""
)
class Aimv2TextModel(Aimv2PreTrainedModel):
main_input_name = "input_ids"
def __init__(self, config: Aimv2TextConfig):
super().__init__(config)
self.config = config
self.embeddings = Aimv2TextEmbeddings(config)
self.encoder = Aimv2Encoder(config)
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.eos_token_id = config.eos_token_id
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.token_embedding
def set_input_embeddings(self, value):
self.embeddings.token_embedding = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
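r"""
Example (a minimal sketch with a randomly initialized model; the token ids below are hypothetical and simply end with the default EOS id):
```python
>>> import torch
>>> from transformers import Aimv2TextConfig, Aimv2TextModel
>>> config = Aimv2TextConfig()
>>> model = Aimv2TextModel(config)
>>> input_ids = torch.tensor([[320, 1125, 539, 320, 2368, 49407]])  # hypothetical ids, last one is EOS
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output  # hidden state at the first EOS position
```"""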
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
hidden_states = self.embeddings(input_ids)
_, seq_len, _ = hidden_states.shape
cache_position = torch.arange(seq_len, device=hidden_states.device)
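# The text encoder uses causal attention; fold the padding mask into a causal mask before encoding.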
if attention_mask is not None:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=None,
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.rms_norm(last_hidden_state)
# Pooled output: take the hidden state at the position of the first EOS token in each sequence.
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1),
]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@auto_docstring
class Aimv2Model(CLIPModel, nn.Module):
def __init__(self, config: Aimv2Config):
nn.Module().__init__(config)
self.projection_dim = config.projection_dim
self.vision_embed_dim = config.vision_config.hidden_size
self.text_embed_dim = config.text_config.hidden_size
self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
self.text_model = Aimv2TextModel._from_config(config.text_config)
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
self.max_log_logit_scale = math.log(config.max_logit_scale)
self.post_init()
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Aimv2Output:
r"""
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Aimv2Model
>>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
... )
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
image_embeds = vision_outputs.pooler_output
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs.pooler_output
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / _get_vector_norm(image_embeds)
text_embeds = text_embeds / _get_vector_norm(text_embeds)
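# Clamp the learnable temperature in log space so the exponentiated scale stays within [1, config.max_logit_scale].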
logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
logits_per_image = logits_per_text.t()
return Aimv2Output(
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
__all__ = [
"Aimv2Config",
"Aimv2VisionConfig",
"Aimv2TextConfig",
"Aimv2VisionModel",
"Aimv2Model",
"Aimv2PreTrainedModel",
"Aimv2TextModel",
]

View File

@ -36,6 +36,8 @@ _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
CONFIG_MAPPING_NAMES = OrderedDict[str, str](
[
# Add configs here
("aimv2", "Aimv2Config"),
("aimv2_vision_model", "Aimv2VisionConfig"),
("albert", "AlbertConfig"),
("align", "AlignConfig"),
("altclip", "AltCLIPConfig"),
@ -406,6 +408,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
MODEL_NAMES_MAPPING = OrderedDict[str, str](
[
# Add full (and cased) model names here
("aimv2", "AIMv2"),
("aimv2_vision_model", "Aimv2VisionModel"),
("albert", "ALBERT"),
("align", "ALIGN"),
("altclip", "AltCLIP"),
@ -857,6 +861,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
("glm4v_text", "glm4v"),
("idefics3_vision", "idefics3"),
("siglip_vision_model", "siglip"),
("aimv2_vision_model", "aimv2"),
("smolvlm_vision", "smolvlm"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),

View File

@ -56,6 +56,8 @@ if TYPE_CHECKING:
else:
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
("aria", ("AriaImageProcessor")),
("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),

View File

@ -32,6 +32,8 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("aimv2", "Aimv2Model"),
("aimv2_vision_model", "Aimv2VisionModel"),
("albert", "AlbertModel"),
("align", "AlignModel"),
("altclip", "AltCLIPModel"),
@ -676,6 +678,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
[
# Model for Image mapping
("aimv2_vision_model", "Aimv2VisionModel"),
("beit", "BeitModel"),
("bit", "BitModel"),
("conditional_detr", "ConditionalDetrModel"),

View File

@ -45,6 +45,7 @@ logger = logging.get_logger(__name__)
PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("aimv2", "CLIPProcessor"),
("align", "AlignProcessor"),
("altclip", "AltCLIPProcessor"),
("aria", "AriaProcessor"),

View File

@ -56,6 +56,13 @@ logger = logging.get_logger(__name__)
# Explicit rather than inferred generics to significantly improves completion suggestion performance for language servers.
TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
[
(
"aimv2",
(
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"albert",
(

View File

@ -475,7 +475,6 @@ class SiglipPreTrainedModel(PreTrainedModel):
_no_split_modules = [
"SiglipTextEmbeddings",
"SiglipEncoderLayer",
"SiglipVisionEmbeddings",
"SiglipEncoderLayer",
"SiglipMultiheadAttentionPoolingHead",

View File

@ -708,7 +708,6 @@ class Siglip2PreTrainedModel(PreTrainedModel):
_no_split_modules = [
"Siglip2TextEmbeddings",
"Siglip2EncoderLayer",
"Siglip2VisionEmbeddings",
"Siglip2EncoderLayer",
"Siglip2MultiheadAttentionPoolingHead",

View File

View File

@ -0,0 +1,742 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch AIMv2 model."""
import inspect
import os
import tempfile
import unittest
import numpy as np
import requests
from parameterized import parameterized
from pytest import mark
from transformers import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
is_torch_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
random_attention_mask,
require_torch_sdpa,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from torch import nn
from transformers import (
Aimv2Model,
Aimv2TextModel,
Aimv2VisionModel,
)
if is_vision_available():
from PIL import Image
from transformers import AutoImageProcessor, AutoProcessor
class Aimv2VisionModelTester:
def __init__(
self,
parent,
batch_size=12,
image_size=30,
patch_size=2,
num_channels=3,
is_training=False,
hidden_size=32,
projection_dim=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return Aimv2VisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
)
def create_and_check_model(self, config, pixel_values):
model = Aimv2VisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
class Aimv2ModelTesterMixin(ModelTesterMixin):
"""
Subclass of ModelTesterMixin with methods specific to testing Aimv2 models.
The SDPA equivalence test is overridden here because Aimv2 models may have text/vision/text+vision inputs,
different output logits, and are not supposed to be used or tested with padding_side="left".
"""
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
if hasattr(model_sdpa, "vision_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
if hasattr(model_sdpa, "text_model"):
self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
@require_torch
class Aimv2VisionModelTest(Aimv2ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as Aimv2 does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (Aimv2VisionModel,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = Aimv2VisionModelTester(self)
self.config_tester = ConfigTester(
self, config_class=Aimv2VisionConfig, has_text_modality=False, hidden_size=37
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="Aimv2 does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
class Aimv2TextModelTester:
def __init__(
self,
parent,
batch_size=12,
seq_length=7,
is_training=False,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
projection_dim=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
config = self.get_config()
return config, input_ids, input_mask
def get_config(self):
return Aimv2TextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = Aimv2TextModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class Aimv2TextModelTest(Aimv2ModelTesterMixin, unittest.TestCase):
all_model_classes = (Aimv2TextModel,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_head_masking = False
test_resize_embeddings = False
def setUp(self):
self.model_tester = Aimv2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Aimv2TextConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Aimv2 does not use inputs_embeds")
def test_inputs_embeds(self):
pass
class Aimv2ModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=False):
if text_kwargs is None:
text_kwargs = {}
if vision_kwargs is None:
vision_kwargs = {}
self.parent = parent
self.text_model_tester = Aimv2TextModelTester(parent, **text_kwargs)
self.vision_model_tester = Aimv2VisionModelTester(parent, **vision_kwargs)
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
self.is_training = is_training
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
return Aimv2Config.from_text_vision_configs(
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = Aimv2Model(config).to(torch_device).eval()
with torch.no_grad():
result = model(input_ids, pixel_values, attention_mask)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"return_loss": False,
}
return config, inputs_dict
@require_torch
class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
additional_model_inputs = ["pixel_values"]
all_model_classes = (Aimv2Model,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": Aimv2Model, "image-feature-extraction": Aimv2VisionModel}
if is_torch_available()
else {}
)
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
_is_composite = True
def setUp(self):
self.model_tester = Aimv2ModelTester(self)
common_properties = ["projection_dim", "logit_scale_init_value"]
self.config_tester = ConfigTester(
self, config_class=Aimv2Config, has_text_modality=False, common_properties=common_properties
)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Aimv2Model does not have input/output embeddings")
def test_model_get_set_embeddings(self):
pass
# Override as the `logit_scale` parameter initialization is different for Aimv2
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
# check if `logit_scale` is initialized as per the original implementation
if name == "logit_scale":
self.assertAlmostEqual(
param.data.item(),
np.log(1 / 0.07),
delta=1e-3,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_load_vision_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Save Aimv2Config and check if we can load Aimv2VisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = Aimv2VisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save Aimv2Config and check if we can load Aimv2TextConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
text_config = Aimv2TextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
dummy_input_ids = inputs_dict["input_ids"]
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
)
self.assertTrue(
torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
)
self.assertTrue(
torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
dummy_input_ids = inputs_dict["input_ids"]
dummy_pixel_mask = inputs_dict["attention_mask"]
# right padding
dummy_pixel_mask[:] = 1
dummy_pixel_mask[:, -1:] = 0
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
)
logits_per_image_eager = outputs.logits_per_image[:, :-1]
logits_per_text_eager = outputs.logits_per_text[:, :-1]
logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1]
logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1]
self.assertTrue(
torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2),
f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}",
)
self.assertTrue(
torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
)
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@is_flaky()
def test_eager_matches_sdpa_inference(self, *args):
# Add only the flaky decorator here and call the parent test method
return getattr(ModelTesterMixin, self._testMethodName)(self)
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest._create_and_check_torchscript with CLIP->Aimv2
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
try:
input_ids = inputs_dict["input_ids"]
pixel_values = inputs_dict["pixel_values"] # Aimv2 needs pixel_values
traced_model = torch.jit.trace(model, (input_ids, pixel_values))
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
@require_vision
@require_torch
class Aimv2ModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
model_name = "yaswanthgali/aimv2-large-patch14-224-lit-HF"
model = Aimv2Model.from_pretrained(model_name, device_map="auto")
processor = AutoProcessor.from_pretrained(model_name)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(
text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt"
).to(model.device)
# Forward pass
with torch.no_grad():
outputs = model(**inputs)
# Verify the logits
self.assertEqual(
outputs.logits_per_image.shape,
torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
)
self.assertEqual(
outputs.logits_per_text.shape,
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
)
# handle device
expected_logits = torch.tensor([[33.3550, 26.4255]]).to(model.device)
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
@require_vision
@require_torch
class Aimv2VisionModelIntegrationTests(unittest.TestCase):
@slow
def test_inference(self):
model_name = "yaswanthgali/aimv2-large-patch14-224-HF"
model = Aimv2VisionModel.from_pretrained(model_name, device_map="auto")
processor = AutoImageProcessor.from_pretrained(model_name)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(image, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model(**inputs)
# Verify logits shape
self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 256, 1024]))
# Verify logits slice
# fmt: off
expected_logits = torch.tensor(
[[ 0.0510, 0.0806, -0.0990, -0.0154],
[ 2.7850, -2.5143, -0.3320, 2.4196],
[ 2.8179, -2.4089, -0.2770, 2.3218],
[ 2.7641, -2.4114, -0.3684, 2.2998],
[ 2.7972, -2.3180, -0.4490, 2.2302],
[ 2.8584, -2.5322, -0.2302, 2.4936],
[-2.7849, 2.4121, 1.3670, -1.5514]]).to(model.device)
# fmt: on
output_slice = output.last_hidden_state.squeeze(0)[0:7, 0:4]
self.assertTrue(torch.allclose(output_slice, expected_logits, atol=1e-3))
@slow
def test_inference_for_native_resolution(self):
model_name = "yaswanthgali/aimv2-large-patch14-native-HF"
model = Aimv2VisionModel.from_pretrained(model_name, device_map="auto")
processor = AutoImageProcessor.from_pretrained(model_name)
image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
inputs = processor(image, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model(**inputs)
# Verify logits shape
self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 1530, 1024]))
# Verify logits slice
# fmt: off
expected_logits = torch.tensor(
[[-1.3342, 0.3720, 0.0963, 0.4159],
[-1.5328, 0.4677, 0.0936, 0.4321],
[-0.3775, -0.2758, -0.0803, -0.5367],
[-1.3877, 0.5561, -1.9064, -1.1766],
[-0.5148, 0.0108, -0.4515, -0.6402],
[-0.3400, -0.1711, -0.1855, -0.4219],
[-1.2877, -0.0585, -0.1646, 0.7420]]).to(model.device)
# fmt: on
output_slice = output.last_hidden_state.squeeze(0)[0:7, 0:4]
self.assertTrue(torch.allclose(output_slice, expected_logits, atol=1e-3))

View File

@ -196,6 +196,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"Aimv2TextModel",
"AlignTextModel",
"AlignVisionModel",
"ClapTextModel",