Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)

Merge d7d23b9d28 into 2d561713f8
Commit 73df42d8f2
@@ -693,6 +693,8 @@
        title: Zamba2
    title: Text models
  - sections:
    - local: model_doc/aimv2
      title: Aimv2
    - local: model_doc/beit
      title: BEiT
    - local: model_doc/bit
104 docs/source/en/model_doc/aimv2.md Normal file
@@ -0,0 +1,104 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# AIMv2

## Overview

The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby.

The abstract from the paper is the following:

*We introduce a novel method for pre-training of large-scale vision encoders. Building on recent advancements in autoregressive pre-training of vision models, we extend this framework to a multimodal setting, i.e., images and text. In this paper, we present AIMV2, a family of generalist vision encoders characterized by a straightforward pre-training process, scalability, and remarkable performance across a range of downstream tasks. This is achieved by pairing the vision encoder with a multimodal decoder that autoregressively generates raw image patches and text tokens. Our encoders excel not only in multimodal evaluations but also in vision benchmarks such as localization, grounding, and classification. Notably, our AIMV2-3B encoder achieves 89.5% accuracy on ImageNet-1k with a frozen trunk. Furthermore, AIMV2 consistently outperforms state-of-the-art contrastive models (e.g., CLIP, SigLIP) in multimodal image understanding across diverse settings.*

This model was contributed by [Yaswanth Gali](https://huggingface.co/yaswanthgali).
The original code can be found [here](https://github.com/apple/ml-aim).

## Usage Examples

Here is an example of image feature extraction, shown with the native-resolution checkpoint (checkpoints trained on resized images work the same way):

```python
import requests
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-native")

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
```
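
The extracted features can then be read from the model output. A minimal follow-up sketch, reusing the objects created above and assuming the checkpoint loads as a vision-only model whose output exposes `last_hidden_state`:

```python
# Patch-level features with shape (batch_size, num_patches, hidden_size)
features = outputs.last_hidden_state
print(features.shape)
```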

Here is an example of zero-shot image classification using the `apple/aimv2-large-patch14-224-lit` checkpoint:

```python
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]

processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")

inputs = processor(
    images=image,
    text=text,
    add_special_tokens=True,
    truncation=True,
    padding=True,
    return_tensors="pt",
)
outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=-1)
```
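
To turn the probabilities into a predicted caption, here is a short illustrative follow-up using the `probs` and `text` variables defined above:

```python
# probs has shape (num_images, num_texts); pick the best caption for the single input image
best_idx = probs.argmax(dim=-1).item()
print(f"{text[best_idx]} (score: {probs[0, best_idx].item():.3f})")
```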

## Aimv2Config

[[autodoc]] Aimv2Config

## Aimv2TextConfig

[[autodoc]] Aimv2TextConfig

## Aimv2VisionConfig

[[autodoc]] Aimv2VisionConfig

## Aimv2Model

[[autodoc]] Aimv2Model
    - forward

## Aimv2VisionModel

[[autodoc]] Aimv2VisionModel
    - forward

## Aimv2TextModel

[[autodoc]] Aimv2TextModel
    - forward
@@ -18,6 +18,7 @@ from ..utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .aimv2 import *
    from .albert import *
    from .align import *
    from .altclip import *
27 src/transformers/models/aimv2/__init__.py Normal file
@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_aimv2 import *
    from .modeling_aimv2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
296 src/transformers/models/aimv2/configuration_aimv2.py Normal file
@@ -0,0 +1,296 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_aimv2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Aimv2VisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Aimv2VisionModel`]. It is used to instantiate a
|
||||
AIMv2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the vision encoder of the AIMv2
|
||||
[apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 2816):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries, keys and values.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the linear layers of the MLP.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation used for initializing all weight matrices.
use_head (`bool`, *optional*, defaults to `True`):
Whether to add an attention pooling head on top of the vision encoder.
is_native (`bool`, *optional*, defaults to `False`):
Whether the checkpoint was trained on native-resolution images (these use 2D sincos position embeddings instead of learned ones).
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Aimv2VisionConfig, Aimv2VisionModel
|
||||
|
||||
>>> # Initializing a Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration
|
||||
>>> configuration = Aimv2VisionConfig()
|
||||
|
||||
>>> # Initializing a Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration
|
||||
>>> model = Aimv2VisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "aimv2_vision_model"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
intermediate_size: int = 2816,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 8,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 224,
|
||||
patch_size: int = 14,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
attention_dropout: float = 0.0,
|
||||
qkv_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
hidden_act: str = "silu",
|
||||
initializer_range: float = 0.02,
|
||||
use_head: bool = True,
|
||||
is_native: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
self.use_head = use_head
|
||||
self.initializer_range = initializer_range
|
||||
self.mlp_bias = mlp_bias
|
||||
self.qkv_bias = qkv_bias
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.is_native = is_native
|
||||
|
||||
|
||||
class Aimv2TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Aimv2TextModel`]. It is used to instantiate a
|
||||
AIMv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the text encoder of the AIMv2
|
||||
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 49408):
|
||||
Vocabulary size of the AIMv2 text model. Defines the number of different tokens that can be represented by
|
||||
the `input_ids` passed when calling [`Aimv2Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 2048):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 6):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
qkv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the linear layers of the MLP.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
|
||||
pad_token_id (`int`, *optional*):
The id of the padding token in the vocabulary.
bos_token_id (`int`, *optional*):
The id of the beginning-of-sequence token in the vocabulary.
|
||||
eos_token_id (`int`, *optional*, defaults to 49407):
|
||||
The id of the end-of-sequence token in the vocabulary.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 77):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation used for initializing all weight matrices.
|
||||
"""
|
||||
|
||||
model_type = "aimv2_text_model"
|
||||
base_config_key = "text_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 49408,
|
||||
hidden_size: int = 768,
|
||||
intermediate_size: int = 2048,
|
||||
num_hidden_layers: int = 12,
|
||||
num_attention_heads: int = 6,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
attention_dropout: float = 0.0,
|
||||
qkv_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
hidden_act: str = "silu",
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: int = 49407,
|
||||
max_position_embeddings: int = 77,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_act = hidden_act
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
self.initializer_range = initializer_range
|
||||
self.mlp_bias = mlp_bias
|
||||
self.qkv_bias = qkv_bias
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
|
||||
class Aimv2Config(PretrainedConfig):
|
||||
r"""
|
||||
[`Aimv2Config`] is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to
|
||||
instantiate a AIMv2 model according to the specified arguments, defining the text model and vision model configs.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the AIMv2
|
||||
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`Aimv2TextConfig`].
|
||||
vision_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`Aimv2VisionConfig`].
|
||||
projection_dim (`int`, *optional*, defaults to 512):
|
||||
Dimensionality of text and vision projection layers.
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
||||
The initial value of the *logit_scale* parameter.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Aimv2Config, Aimv2Model
|
||||
|
||||
>>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration
|
||||
>>> configuration = Aimv2Config()
|
||||
|
||||
>>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration
|
||||
>>> model = Aimv2Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
|
||||
>>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig
|
||||
>>> from transformers import Aimv2TextConfig, Aimv2VisionConfig
|
||||
|
||||
>>> # Initializing a AIMv2Text and AIMv2Vision configuration
|
||||
>>> config_text = Aimv2TextConfig()
|
||||
>>> config_vision = Aimv2VisionConfig()
|
||||
|
||||
>>> config = Aimv2Config(text_config=config_text, vision_config=config_vision)
|
||||
```"""
|
||||
|
||||
model_type = "aimv2"
|
||||
sub_configs = {"text_config": Aimv2TextConfig, "vision_config": Aimv2VisionConfig}
|
||||
|
||||
def __init__(
|
||||
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if text_config is None:
|
||||
text_config = {}
|
||||
logger.info("`text_config` is `None`. Initializing the `Aimv2TextConfig` with default values.")
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {}
|
||||
logger.info("`vision_config` is `None`. initializing the `Aimv2VisionConfig` with default values.")
|
||||
|
||||
self.text_config = Aimv2TextConfig(**text_config)
|
||||
self.vision_config = Aimv2VisionConfig(**vision_config)
|
||||
self.projection_dim = projection_dim
|
||||
self.logit_scale_init_value = logit_scale_init_value
|
||||
self.max_logit_scale = 100.0
|
||||
|
||||
@classmethod
|
||||
def from_text_vision_configs(cls, text_config: Aimv2TextConfig, vision_config: Aimv2VisionConfig, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`Aimv2Config`] (or a derived class) from aimv2 text model configuration and aimv2 vision
|
||||
model configuration.
|
||||
|
||||
Returns:
|
||||
[`Aimv2Config`]: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
||||
|
||||
|
||||
__all__ = ["Aimv2Config", "Aimv2VisionConfig", "Aimv2TextConfig"]
|
@@ -0,0 +1,269 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
Aimv2Config,
|
||||
Aimv2Model,
|
||||
Aimv2VisionConfig,
|
||||
Aimv2VisionModel,
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
)
|
||||
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL = {
|
||||
# Embeddings
|
||||
r"preprocessor.patchifier.proj": r"embeddings.patch_embed",
|
||||
r"preprocessor.pos_embed": r"embeddings.position_embedding.weight",
|
||||
r"preprocessor.patchifier.norm.weight": r"embeddings.rms_norm.weight",
|
||||
# Encoder Layers
|
||||
r"trunk.blocks.(\d+).attn.qkv": r"encoder.layers.\1.attention.qkv",
|
||||
r"trunk.blocks.(\d+).attn.proj": r"encoder.layers.\1.attention.out_proj",
|
||||
r"trunk.blocks.(\d+).mlp.fc1": r"encoder.layers.\1.ffn.gate_proj",
|
||||
r"trunk.blocks.(\d+).mlp.fc2": r"encoder.layers.\1.ffn.down_proj",
|
||||
r"trunk.blocks.(\d+).mlp.fc3": r"encoder.layers.\1.ffn.up_proj",
|
||||
# Normalization Layers
|
||||
r"trunk.blocks.(\d+).norm_1": r"encoder.layers.\1.rms_norm1",
|
||||
r"trunk.blocks.(\d+).norm_2": r"encoder.layers.\1.rms_norm2",
|
||||
# Final Norm
|
||||
r"trunk.post_trunk_norm": r"rms_norm",
|
||||
}
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
# Vision Embeddings
|
||||
r"image_encoder.preprocessor.patchifier.proj": r"vision_model.embeddings.patch_embed",
|
||||
r"image_encoder.preprocessor.pos_embed": r"vision_model.embeddings.position_embedding.weight",
|
||||
r"image_encoder.preprocessor.patchifier.norm.weight": r"vision_model.embeddings.rms_norm.weight",
|
||||
# Vision Encoder Layers
|
||||
r"image_encoder.trunk.blocks.(\d+).attn.qkv": r"vision_model.encoder.layers.\1.attention.qkv",
|
||||
r"image_encoder.trunk.blocks.(\d+).attn.proj": r"vision_model.encoder.layers.\1.attention.out_proj",
|
||||
r"image_encoder.trunk.blocks.(\d+).mlp.fc1": r"vision_model.encoder.layers.\1.ffn.gate_proj",
|
||||
r"image_encoder.trunk.blocks.(\d+).mlp.fc2": r"vision_model.encoder.layers.\1.ffn.down_proj",
|
||||
r"image_encoder.trunk.blocks.(\d+).mlp.fc3": r"vision_model.encoder.layers.\1.ffn.up_proj",
|
||||
# Normalization Layers
|
||||
r"image_encoder.trunk.blocks.(\d+).norm_1": r"vision_model.encoder.layers.\1.rms_norm1",
|
||||
r"image_encoder.trunk.blocks.(\d+).norm_2": r"vision_model.encoder.layers.\1.rms_norm2",
|
||||
r"image_encoder.trunk.post_trunk_norm": r"vision_model.rms_norm",
|
||||
r"image_projector": r"visual_projection",
|
||||
# Vision Head
|
||||
r"image_encoder.head.cls_token": r"vision_model.head.cls_token",
|
||||
r"image_encoder.head.k": r"vision_model.head.k_proj",
|
||||
r"image_encoder.head.v": r"vision_model.head.v_proj",
|
||||
r"image_encoder.head.linear": r"vision_model.head.output_proj",
|
||||
# Text Embeddings
|
||||
r"text_encoder.preprocessor.text_embedding.weight": r"text_model.embeddings.token_embedding.weight",
|
||||
r"text_encoder.preprocessor.positional_embedding": r"text_model.embeddings.position_embedding.weight",
|
||||
# Text Encoder Layers
|
||||
r"text_encoder.trunk.blocks.(\d+).attn.qkv": r"text_model.encoder.layers.\1.attention.qkv",
|
||||
r"text_encoder.trunk.blocks.(\d+).attn.proj": r"text_model.encoder.layers.\1.attention.out_proj",
|
||||
r"text_encoder.trunk.blocks.(\d+).mlp.fc1": r"text_model.encoder.layers.\1.ffn.gate_proj",
|
||||
r"text_encoder.trunk.blocks.(\d+).mlp.fc2": r"text_model.encoder.layers.\1.ffn.down_proj",
|
||||
r"text_encoder.trunk.blocks.(\d+).mlp.fc3": r"text_model.encoder.layers.\1.ffn.up_proj",
|
||||
# Text Normalization Layers
|
||||
r"text_encoder.trunk.blocks.(\d+).norm_1": r"text_model.encoder.layers.\1.rms_norm1",
|
||||
r"text_encoder.trunk.blocks.(\d+).norm_2": r"text_model.encoder.layers.\1.rms_norm2",
|
||||
r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
|
||||
r"text_projector": r"text_projection",
|
||||
r"log_logit_scale": r"logit_scale",
|
||||
}
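# Illustrative examples (not part of the original mapping): applying the patterns above converts, e.g.,
# "image_encoder.trunk.blocks.0.attn.qkv.weight" -> "vision_model.encoder.layers.0.attention.qkv.weight"
# and "text_projector.weight" -> "text_projection.weight".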
|
||||
|
||||
|
||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
|
||||
# Download only the model.safetensors file
|
||||
directory_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
allow_patterns=["model.safetensors"],
|
||||
)
|
||||
|
||||
original_state_dict = {}
|
||||
safetensor_path = f"{directory_path}/model.safetensors"
|
||||
|
||||
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: dict, ORIGINAL_TO_CONVERTED_KEY_MAPPING: dict):
|
||||
"""Converts state dict keys from the old format to the new format."""
|
||||
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text = "\n".join(state_dict_keys)
|
||||
new_text = old_text
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||
if replacement is None:
|
||||
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||
continue
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||
return output_dict
|
||||
|
||||
|
||||
def split_qkv_tensor(key, tensor):
|
||||
"""Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""
|
||||
|
||||
new_keys = ["q_proj", "k_proj", "v_proj"]
|
||||
split_size = tensor.shape[0] // 3
|
||||
split_tensors = torch.split(tensor, split_size, dim=0)
|
||||
|
||||
return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}
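# For example, a fused "attention.qkv.weight" of shape (3 * hidden_size, hidden_size) is split along dim 0 into
# three (hidden_size, hidden_size) tensors keyed "attention.q_proj.weight", "attention.k_proj.weight" and
# "attention.v_proj.weight".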
|
||||
|
||||
|
||||
def get_model_config_mapping(model_id: str):
|
||||
"""Determines the correct model, config, and key mappings based on the checkpoint name."""
|
||||
|
||||
if model_id == "apple/aimv2-large-patch14-224-lit":
|
||||
return Aimv2Model, Aimv2Config, ORIGINAL_TO_CONVERTED_KEY_MAPPING
|
||||
else:
|
||||
return Aimv2VisionModel, Aimv2VisionConfig, ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL
|
||||
|
||||
|
||||
def write_model(
|
||||
hf_repo_id: str,
|
||||
output_dir: str,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
"""
|
||||
Converts a model checkpoint to Hugging Face format and saves it.
|
||||
|
||||
Args:
|
||||
hf_repo_id (str): The Hugging Face repo ID to load from.
|
||||
output_dir (str): The directory to save the converted model.
|
||||
safe_serialization (bool): Whether to use safe serialization.
|
||||
|
||||
Returns:
|
||||
model: The reloaded Hugging Face model.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Get the appropriate model, config, and key mapping
|
||||
model_class, config_class, key_mapping = get_model_config_mapping(hf_repo_id)
|
||||
|
||||
# Load config and original state dict
|
||||
config = config_class.from_pretrained(hf_repo_id)
|
||||
|
||||
# Only the `apple/aimv2-large-patch14-224-lit` checkpoint uses the attention pooling head, so disable it for all other checkpoints.
|
||||
if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
|
||||
config.use_head = False
|
||||
|
||||
if hf_repo_id == "apple/aimv2-large-patch14-native":
|
||||
config.is_native = True
|
||||
|
||||
original_state_dict = load_original_state_dict(hf_repo_id)
|
||||
|
||||
print("Converting model...")
|
||||
|
||||
state_dict = {}
|
||||
result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
|
||||
all_keys = list(original_state_dict.keys())
|
||||
|
||||
for key in all_keys:
|
||||
value = original_state_dict[key]
|
||||
new_key = result.pop(key)
|
||||
|
||||
if "qkv" in new_key:
|
||||
qkv_state_dict = split_qkv_tensor(new_key, value)
|
||||
state_dict.update(qkv_state_dict)
|
||||
else:
|
||||
state_dict[new_key] = value
|
||||
|
||||
# Check if position embeddings exist before squeezing
|
||||
if new_key.endswith("position_embedding.weight"):
|
||||
state_dict[new_key] = value.squeeze(0)
|
||||
|
||||
print(f"Loading the checkpoint in a {model_class.__name__}.")
|
||||
model = model_class(config)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
print("Checkpoint loaded successfully.")
|
||||
|
||||
print("Saving the model.")
|
||||
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
|
||||
del state_dict, model
|
||||
gc.collect()
|
||||
|
||||
print("Reloading the model to check if it's saved correctly.")
|
||||
model = model_class.from_pretrained(output_dir, device_map="auto")
|
||||
print("Model reloaded successfully.")
|
||||
return model
|
||||
|
||||
|
||||
def write_image_processor(hf_repo_id: str, output_dir: str):
|
||||
if hf_repo_id == "apple/aimv2-large-patch14-224-lit":
|
||||
image_processor = AutoProcessor.from_pretrained(hf_repo_id, use_fast=True)
|
||||
else:
|
||||
image_processor = AutoImageProcessor.from_pretrained(hf_repo_id, use_fast=True)
|
||||
image_processor.save_pretrained(output_dir)
|
||||
return image_processor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--hf_repo_id",
|
||||
default="apple/aimv2-large-patch14-224",
|
||||
help="Location of official weights from apple on HF",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default="aimv2_model",
|
||||
help="Location to write the converted model and processor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Whether or not to push the converted model to the huggingface hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_repo_id",
|
||||
default=None,
|
||||
help="Huggingface hub repo to write the converted model and processor",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model = write_model(
|
||||
hf_repo_id=args.hf_repo_id,
|
||||
output_dir=args.output_dir,
|
||||
safe_serialization=args.safe_serialization,
|
||||
)
|
||||
|
||||
image_processor = write_image_processor(
|
||||
hf_repo_id=args.hf_repo_id,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
print("Pushing to hub...")
|
||||
model.push_to_hub(args.hub_repo_id)
|
||||
image_processor.push_to_hub(args.hub_repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
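# Illustrative invocation (the script filename below is a placeholder, not taken from this diff):
#   python convert_aimv2_original_to_hf.py \
#       --hf_repo_id apple/aimv2-large-patch14-224 \
#       --output_dir aimv2_model \
#       --push_to_hub --hub_repo_id <your-namespace>/aimv2-large-patch14-224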
834 src/transformers/models/aimv2/modeling_aimv2.py Normal file
@@ -0,0 +1,834 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_aimv2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple
|
||||
from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring
|
||||
class Aimv2Output(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||
Contrastive loss for image-text similarity.
|
||||
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||
similarity scores.
|
||||
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||
similarity scores.
|
||||
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||
The text embeddings obtained by applying the projection layer to the pooled output of [`Aimv2TextModel`].
|
||||
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||
The image embeddings obtained by applying the projection layer to the pooled output of [`Aimv2VisionModel`].
|
||||
text_model_output (`BaseModelOutputWithPooling`):
|
||||
The output of the [`Aimv2TextModel`].
|
||||
vision_model_output (`BaseModelOutputWithPooling`):
|
||||
The output of the [`Aimv2VisionModel`].
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits_per_image: Optional[torch.FloatTensor] = None
|
||||
logits_per_text: Optional[torch.FloatTensor] = None
|
||||
text_embeds: Optional[torch.FloatTensor] = None
|
||||
image_embeds: Optional[torch.FloatTensor] = None
|
||||
text_model_output: BaseModelOutputWithPooling = None
|
||||
vision_model_output: BaseModelOutputWithPooling = None
|
||||
|
||||
def to_tuple(self) -> tuple[Any]:
|
||||
return tuple(
|
||||
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
||||
for k in self.keys()
|
||||
)
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class Aimv2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Aimv2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Aimv2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Aimv2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.patch_size = config.patch_size
|
||||
self.patch_embed = nn.Conv2d(
|
||||
config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
|
||||
)
|
||||
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
num_patches = (config.image_size // config.patch_size) ** 2
|
||||
if not self.config.is_native:
|
||||
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
|
||||
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def build_2d_sincos_position_embedding(
|
||||
height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
|
||||
) -> torch.Tensor:
|
||||
grid_w = torch.arange(int(width), dtype=dtype, device=device)
|
||||
grid_h = torch.arange(int(height), dtype=dtype, device=device)
|
||||
grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
||||
|
||||
pos_dim = embed_dim // 4
|
||||
omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
|
||||
omega = 1.0 / (temperature**omega)
|
||||
|
||||
out_h = grid_h.flatten()[..., None] @ omega[None, :]
|
||||
out_w = grid_w.flatten()[..., None] @ omega[None, :]
|
||||
|
||||
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.size()
|
||||
hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
|
||||
hidden_states = self.rms_norm(hidden_states)
|
||||
|
||||
if self.config.is_native:
|
||||
pos_embed = self.build_2d_sincos_position_embedding(
|
||||
height // self.patch_size,
|
||||
width // self.patch_size,
|
||||
embed_dim=self.config.hidden_size,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
else:
|
||||
pos_embed = self.position_embedding(self.position_ids)
|
||||
|
||||
hidden_states = hidden_states + pos_embed
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Aimv2TextEmbeddings(nn.Module):
|
||||
def __init__(self, config: Aimv2TextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer(
|
||||
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
max_position_embedding = self.position_embedding.weight.shape[0]
|
||||
|
||||
if seq_length > max_position_embedding:
|
||||
raise ValueError(
|
||||
f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
|
||||
f"{seq_length} and max_position_embeddings: {max_position_embedding}"
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Aimv2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Aimv2EncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__()
|
||||
self.attention = Aimv2Attention(config)
|
||||
self.ffn = Aimv2MLP(config)
|
||||
self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states = self.rms_norm1(hidden_states)
|
||||
attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = hidden_states + attn_output
|
||||
norm_hidden_states = self.rms_norm2(hidden_states)
|
||||
mlp_output = self.ffn(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + mlp_output
|
||||
return (hidden_states, attn_weights) if output_attentions else (hidden_states, None)
|
||||
|
||||
|
||||
class Aimv2Encoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`Aimv2EncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: Aimv2Config
|
||||
"""
|
||||
|
||||
def __init__(self, config: Aimv2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Ignore copy
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> BaseModelOutput:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=encoder_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
class Aimv2AttentionPoolingHead(nn.Module):
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
|
||||
self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, hidden_dim = hidden_states.shape
|
||||
|
||||
cls_token = self.cls_token.expand(batch_size, -1, -1)
|
||||
|
||||
key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
|
||||
value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
|
||||
query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads)
|
||||
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query, key, value)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim)
|
||||
attn_output = attn_output.mean(dim=1)
|
||||
|
||||
output = self.output_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Aimv2PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models. The model is only intended for inference and doesn't support finetuning.
|
||||
"""
|
||||
|
||||
config_class = Aimv2Config
|
||||
base_model_prefix = "aimv2"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Aimv2EncoderLayer",
|
||||
"Aimv2AttentionPoolingHead",
|
||||
"Aimv2VisionEmbeddings",
|
||||
"Aimv2TextEmbeddings",
|
||||
]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.vision_config.initializer_range
|
||||
if hasattr(self.config, "vision_config")
|
||||
else self.config.initializer_range
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, Aimv2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif hasattr(module, "logit_scale"):
|
||||
if isinstance(module.logit_scale, nn.Parameter):
|
||||
module.logit_scale.data.fill_(math.log(1 / 0.07))
|
||||
elif isinstance(module, Aimv2AttentionPoolingHead):
|
||||
module.cls_token.data.normal_(mean=0.0, std=std)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Vision model from AIMv2 without any head or projection on top.
|
||||
"""
|
||||
)
|
||||
class Aimv2VisionModel(Aimv2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.embeddings = Aimv2VisionEmbeddings(config)
|
||||
self.encoder = Aimv2Encoder(config)
|
||||
# The only change from SiglipVisionTransformer is layernorm -> rms_norm.
|
||||
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
self.use_head = config.use_head
|
||||
if self.use_head:
|
||||
self.head = Aimv2AttentionPoolingHead(config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.embeddings.patch_embed
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> BaseModelOutputWithPooling:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Aimv2VisionModel
|
||||
|
||||
>>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
|
||||
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled features
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.rms_norm(last_hidden_state)
|
||||
|
||||
pooler_output = self.head(last_hidden_state) if self.use_head else None
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooler_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The text model from AIMv2 without any head or projection on top.
|
||||
"""
|
||||
)
|
||||
class Aimv2TextModel(Aimv2PreTrainedModel):
|
||||
main_input_name = "input_ids"
|
||||
|
||||
def __init__(self, config: Aimv2TextConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.embeddings = Aimv2TextEmbeddings(config)
|
||||
self.encoder = Aimv2Encoder(config)
|
||||
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.token_embedding = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> BaseModelOutputWithPooling:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
_, seq_len, _ = hidden_states.shape
|
||||
|
||||
cache_position = torch.arange(seq_len, device=hidden_states.device)
|
||||
if attention_mask is not None:
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=None,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.rms_norm(last_hidden_state)
|
||||
|
||||
# Get pooled output
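# The pooled output is the hidden state at the position of the first EOS token in each sequence (CLIP-style EOS pooling).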
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1),
|
||||
]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and is used to make the
|
||||
model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
|
||||
"""
|
||||
square_tensor = torch.pow(tensor, 2)
|
||||
sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
|
||||
normed_tensor = torch.pow(sum_tensor, 0.5)
|
||||
return normed_tensor
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Aimv2Model(Aimv2PreTrainedModel):
|
||||
config_class = Aimv2Config
|
||||
_no_split_modules = ["Aimv2TextEmbeddings", "Aimv2EncoderLayer", "Aimv2VisionEmbeddings"]
|
||||
|
||||
def __init__(self, config: Aimv2Config):
|
||||
super().__init__(config)
|
||||
|
||||
self.projection_dim = config.projection_dim
|
||||
self.vision_embed_dim = config.vision_config.hidden_size
|
||||
self.text_embed_dim = config.text_config.hidden_size
|
||||
|
||||
self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
|
||||
self.text_model = Aimv2TextModel._from_config(config.text_config)
|
||||
|
||||
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
||||
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
|
||||
self.max_log_logit_scale = math.log(config.max_logit_scale)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@auto_docstring
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`Aimv2TextModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Aimv2Model
|
||||
|
||||
>>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
>>> text_features = model.get_text_features(**inputs)
|
||||
```"""
|
||||
# Use AIMV2 model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
text_outputs: BaseModelOutputWithPooling = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs.pooler_output
|
||||
text_features = self.text_projection(pooled_output)
|
||||
|
||||
return text_features
|
||||
|
||||
@auto_docstring
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`Aimv2VisionModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Aimv2Model
|
||||
|
||||
>>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> image_features = model.get_image_features(**inputs)
|
||||
```"""
|
||||
# Use AIMV2 model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs.pooler_output
|
||||
image_features = self.visual_projection(pooled_output)
|
||||
|
||||
return image_features
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
return_loss: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> Aimv2Output:
|
||||
r"""
|
||||
return_loss (`bool`, *optional*):
|
||||
Whether or not to return the contrastive loss.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Aimv2Model
|
||||
|
||||
>>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(
|
||||
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
||||
... )
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
text_outputs: BaseModelOutputWithPooling = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs.pooler_output
|
||||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
text_embeds = text_outputs.pooler_output
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / _get_vector_norm(image_embeds)
|
||||
text_embeds = text_embeds / _get_vector_norm(text_embeds)
|
||||
|
||||
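# logit_scale is stored in log space; clamping to [0, log(max_logit_scale)] before exponentiating keeps the temperature bounded.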
logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
|
||||
logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
return Aimv2Output(
|
||||
logits_per_image=logits_per_image,
|
||||
logits_per_text=logits_per_text,
|
||||
text_embeds=text_embeds,
|
||||
image_embeds=image_embeds,
|
||||
text_model_output=text_outputs,
|
||||
vision_model_output=vision_outputs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Aimv2VisionModel", "Aimv2Model", "Aimv2PreTrainedModel", "Aimv2TextModel"]
|
728
src/transformers/models/aimv2/modular_aimv2.py
Normal file
@ -0,0 +1,728 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pytorch implementation of AIMv2 Model"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
)
|
||||
from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
|
||||
from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
|
||||
from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
|
||||
from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoder, SiglipOutput
|
||||
|
||||
|
||||
class Aimv2VisionConfig(SiglipVisionConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Aimv2VisionModel`]. It is used to instantiate a
|
||||
AIMv2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the vision encoder of the AIMv2
|
||||
[apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 2816):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
qkv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the linear layers or not.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation used for initializing all weight matrices.
|
||||
use_head (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the attention pooling head or not.
|
||||
is_native (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a checkpoint trained at native image resolution or not.
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Aimv2VisionConfig, Aimv2VisionModel
|
||||
|
||||
>>> # Initializing a Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration
|
||||
>>> configuration = Aimv2VisionConfig()
|
||||
|
||||
>>> # Initializing a Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration
|
||||
>>> model = Aimv2VisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
intermediate_size: int = 2816,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 8,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 224,
|
||||
patch_size: int = 14,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
attention_dropout: float = 0.0,
|
||||
qkv_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
hidden_act: str = "silu",
|
||||
initializer_range: float = 0.02,
|
||||
use_head: bool = True,
|
||||
is_native: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
hidden_act=hidden_act,
|
||||
num_channels=num_channels,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
qkv_bias=qkv_bias,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.use_head = use_head
|
||||
self.initializer_range = initializer_range
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
self.qkv_bias = qkv_bias
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.is_native = is_native
|
||||
|
||||
del self.layer_norm_eps
|
||||
|
||||
|
||||
class Aimv2TextConfig(SiglipTextConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Aimv2TextModel`]. It is used to instantiate a
|
||||
AIMv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the text encoder of the AIMv2
|
||||
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 49408):
|
||||
Vocabulary size of the AIMv2 text model. Defines the number of different tokens that can be represented by
|
||||
the `input_ids` passed when calling [`Aimv2Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 2048):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 6):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
qkv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the linear layers or not.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
|
||||
pad_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the padding token in the vocabulary.
|
||||
bos_token_id (`int`, *optional*, defaults to 49406):
|
||||
The id of the beginning-of-sequence token in the vocabulary.
|
||||
eos_token_id (`int`, *optional*, defaults to 49407):
|
||||
The id of the end-of-sequence token in the vocabulary.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 77):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation used for initializing all weight matrices.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 49408,
|
||||
hidden_size: int = 768,
|
||||
intermediate_size: int = 2048,
|
||||
num_hidden_layers: int = 12,
|
||||
num_attention_heads: int = 6,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
attention_dropout: float = 0.0,
|
||||
qkv_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
hidden_act: str = "silu",
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: int = 49407,
|
||||
max_position_embeddings: int = 77,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
hidden_act=hidden_act,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.initializer_range = initializer_range
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
self.qkv_bias = qkv_bias
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
del self.bos_token_id
|
||||
del self.pad_token_id
|
||||
del self.projection_size
|
||||
del self.layer_norm_eps
|
||||
|
||||
|
||||
class Aimv2Config(SiglipConfig):
|
||||
r"""
|
||||
[`Aimv2Config`] is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to
|
||||
instantiate a AIMv2 model according to the specified arguments, defining the text model and vision model configs.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the AIMv2
|
||||
[apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`Aimv2TextConfig`].
|
||||
vision_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`Aimv2VisionConfig`].
|
||||
projection_dim (`int`, *optional*, defaults to 512):
|
||||
Dimensionality of text and vision projection layers.
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
||||
The initial value of the *logit_scale* parameter.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Aimv2Config, Aimv2Model
|
||||
|
||||
>>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration
|
||||
>>> configuration = Aimv2Config()
|
||||
|
||||
>>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration
|
||||
>>> model = Aimv2Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
|
||||
>>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig
|
||||
>>> from transformers import Aimv2TextConfig, Aimv2VisionConfig
|
||||
|
||||
>>> # Initializing a AIMv2Text and AIMv2Vision configuration
|
||||
>>> config_text = Aimv2TextConfig()
|
||||
>>> config_vision = Aimv2VisionConfig()
|
||||
|
||||
>>> config = Aimv2Config(text_config=config_text, vision_config=config_vision)
|
||||
```"""
|
||||
|
||||
def __init__(
|
||||
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
|
||||
):
|
||||
super().__init__(text_config, vision_config, **kwargs)
|
||||
self.projection_dim = projection_dim
|
||||
self.logit_scale_init_value = logit_scale_init_value
|
||||
self.max_logit_scale = 100.0
|
||||
|
||||
del self.initializer_factor
|
||||
|
||||
|
||||
class Aimv2Output(SiglipOutput):
|
||||
pass
|
||||
|
||||
|
||||
class Aimv2RMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Aimv2MLP(LlamaMLP):
|
||||
pass
|
||||
|
||||
|
||||
class Aimv2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.patch_size = config.patch_size
|
||||
self.patch_embed = nn.Conv2d(
|
||||
config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
|
||||
)
|
||||
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
num_patches = (config.image_size // config.patch_size) ** 2
|
||||
if not self.config.is_native:
|
||||
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
|
||||
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def build_2d_sincos_position_embedding(
|
||||
height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
|
||||
) -> torch.Tensor:
|
||||
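# Fixed (non-learned) 2D sin-cos embedding: embed_dim is split into four blocks holding sin/cos of the scaled row index and sin/cos of the scaled column index.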
grid_w = torch.arange(int(width), dtype=dtype, device=device)
|
||||
grid_h = torch.arange(int(height), dtype=dtype, device=device)
|
||||
grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
||||
|
||||
pos_dim = embed_dim // 4
|
||||
omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
|
||||
omega = 1.0 / (temperature**omega)
|
||||
|
||||
out_h = grid_h.flatten()[..., None] @ omega[None, :]
|
||||
out_w = grid_w.flatten()[..., None] @ omega[None, :]
|
||||
|
||||
return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.size()
|
||||
hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
|
||||
hidden_states = self.rms_norm(hidden_states)
|
||||
|
||||
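# Native-resolution checkpoints build sin-cos position embeddings on the fly for the current patch grid; fixed-resolution checkpoints use a learned position embedding table.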
if self.config.is_native:
|
||||
pos_embed = self.build_2d_sincos_position_embedding(
|
||||
height // self.patch_size,
|
||||
width // self.patch_size,
|
||||
embed_dim=self.config.hidden_size,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
else:
|
||||
pos_embed = self.position_embedding(self.position_ids)
|
||||
|
||||
hidden_states = hidden_states + pos_embed
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Aimv2TextEmbeddings(CLIPTextEmbeddings):
|
||||
pass
|
||||
|
||||
|
||||
class Aimv2Attention(SiglipAttention):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
|
||||
|
||||
class Aimv2EncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__()
|
||||
self.attention = Aimv2Attention(config)
|
||||
self.ffn = Aimv2MLP(config)
|
||||
self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
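# Pre-norm transformer block: RMSNorm -> attention -> residual add, then RMSNorm -> MLP -> residual add.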
norm_hidden_states = self.rms_norm1(hidden_states)
|
||||
attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = hidden_states + attn_output
|
||||
norm_hidden_states = self.rms_norm2(hidden_states)
|
||||
mlp_output = self.ffn(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + mlp_output
|
||||
return (hidden_states, attn_weights) if output_attentions else (hidden_states, None)
|
||||
|
||||
|
||||
class Aimv2Encoder(SiglipEncoder):
|
||||
pass
|
||||
|
||||
|
||||
class Aimv2AttentionPoolingHead(nn.Module):
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
|
||||
self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, hidden_dim = hidden_states.shape
|
||||
|
||||
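# A learned CLS token serves as the query and attends over all patch tokens via scaled dot-product attention to produce a single pooled vector.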
cls_token = self.cls_token.expand(batch_size, -1, -1)
|
||||
|
||||
key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
|
||||
value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
|
||||
query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads)
|
||||
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query, key, value)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim)
|
||||
attn_output = attn_output.mean(dim=1)
|
||||
|
||||
output = self.output_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Aimv2PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models. The model is only intended for inference and doesn't support finetuning.
|
||||
"""
|
||||
|
||||
config_class = Aimv2Config
|
||||
base_model_prefix = "aimv2"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Aimv2EncoderLayer",
|
||||
"Aimv2AttentionPoolingHead",
|
||||
"Aimv2VisionEmbeddings",
|
||||
"Aimv2TextEmbeddings",
|
||||
]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.vision_config.initializer_range
|
||||
if hasattr(self.config, "vision_config")
|
||||
else self.config.initializer_range
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, Aimv2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
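# Following the CLIP convention, logit_scale is initialized to log(1 / 0.07).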
elif hasattr(module, "logit_scale"):
|
||||
if isinstance(module.logit_scale, nn.Parameter):
|
||||
module.logit_scale.data.fill_(math.log(1 / 0.07))
|
||||
elif isinstance(module, Aimv2AttentionPoolingHead):
|
||||
module.cls_token.data.normal_(mean=0.0, std=std)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Vision model from AIMv2 without any head or projection on top.
|
||||
"""
|
||||
)
|
||||
class Aimv2VisionModel(Aimv2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self, config: Aimv2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.embeddings = Aimv2VisionEmbeddings(config)
|
||||
self.encoder = Aimv2Encoder(config)
|
||||
# The only change from SiglipVisionTransformer is layernorm -> rms_norm.
|
||||
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
self.use_head = config.use_head
|
||||
if self.use_head:
|
||||
self.head = Aimv2AttentionPoolingHead(config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.embeddings.patch_embed
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> BaseModelOutputWithPooling:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Aimv2VisionModel
|
||||
|
||||
>>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
|
||||
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled features
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.rms_norm(last_hidden_state)
|
||||
|
||||
pooler_output = self.head(last_hidden_state) if self.use_head else None
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooler_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The text model from AIMv2 without any head or projection on top.
|
||||
"""
|
||||
)
|
||||
class Aimv2TextModel(Aimv2PreTrainedModel):
|
||||
main_input_name = "input_ids"
|
||||
|
||||
def __init__(self, config: Aimv2TextConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.embeddings = Aimv2TextEmbeddings(config)
|
||||
self.encoder = Aimv2Encoder(config)
|
||||
self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.token_embedding = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> BaseModelOutputWithPooling:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
_, seq_len, _ = hidden_states.shape
|
||||
|
||||
cache_position = torch.arange(seq_len, device=hidden_states.device)
|
||||
if attention_mask is not None:
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=None,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.rms_norm(last_hidden_state)
|
||||
|
||||
# Get pooled output
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1),
|
||||
]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Aimv2Model(CLIPModel, nn.Module):
|
||||
def __init__(self, config: Aimv2Config):
|
||||
nn.Module().__init__(config)
|
||||
|
||||
self.projection_dim = config.projection_dim
|
||||
self.vision_embed_dim = config.vision_config.hidden_size
|
||||
self.text_embed_dim = config.text_config.hidden_size
|
||||
|
||||
self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
|
||||
self.text_model = Aimv2TextModel._from_config(config.text_config)
|
||||
|
||||
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
||||
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
|
||||
self.max_log_logit_scale = math.log(config.max_logit_scale)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
return_loss: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> Aimv2Output:
|
||||
r"""
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Aimv2Model
|
||||
|
||||
>>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
>>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(
|
||||
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
||||
... )
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
text_outputs: BaseModelOutputWithPooling = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs.pooler_output
|
||||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
text_embeds = text_outputs.pooler_output
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / _get_vector_norm(image_embeds)
|
||||
text_embeds = text_embeds / _get_vector_norm(text_embeds)
|
||||
|
||||
logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
|
||||
logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
return Aimv2Output(
|
||||
logits_per_image=logits_per_image,
|
||||
logits_per_text=logits_per_text,
|
||||
text_embeds=text_embeds,
|
||||
image_embeds=image_embeds,
|
||||
text_model_output=text_outputs,
|
||||
vision_model_output=vision_outputs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Aimv2Config",
|
||||
"Aimv2VisionConfig",
|
||||
"Aimv2TextConfig",
|
||||
"Aimv2VisionModel",
|
||||
"Aimv2Model",
|
||||
"Aimv2PreTrainedModel",
|
||||
"Aimv2TextModel",
|
||||
]
|
@ -36,6 +36,8 @@ _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
|
||||
CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
[
|
||||
# Add configs here
|
||||
("aimv2", "Aimv2Config"),
|
||||
("aimv2_vision_model", "Aimv2VisionConfig"),
|
||||
("albert", "AlbertConfig"),
|
||||
("align", "AlignConfig"),
|
||||
("altclip", "AltCLIPConfig"),
|
||||
@ -406,6 +408,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
[
|
||||
# Add full (and cased) model names here
|
||||
("aimv2", "AIMv2"),
|
||||
("aimv2_vision_model", "Aimv2VisionModel"),
|
||||
("albert", "ALBERT"),
|
||||
("align", "ALIGN"),
|
||||
("altclip", "AltCLIP"),
|
||||
@ -857,6 +861,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
|
||||
("glm4v_text", "glm4v"),
|
||||
("idefics3_vision", "idefics3"),
|
||||
("siglip_vision_model", "siglip"),
|
||||
("aimv2_vision_model", "aimv2"),
|
||||
("smolvlm_vision", "smolvlm"),
|
||||
("chinese_clip_vision_model", "chinese_clip"),
|
||||
("rt_detr_resnet", "rt_detr"),
|
||||
|
@ -56,6 +56,8 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("aria", ("AriaImageProcessor")),
|
||||
("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
||||
|
@ -32,6 +32,8 @@ logger = logging.get_logger(__name__)
|
||||
MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
("aimv2", "Aimv2Model"),
|
||||
("aimv2_vision_model", "Aimv2VisionModel"),
|
||||
("albert", "AlbertModel"),
|
||||
("align", "AlignModel"),
|
||||
("altclip", "AltCLIPModel"),
|
||||
@ -676,6 +678,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Image mapping
|
||||
("aimv2_vision_model", "Aimv2VisionModel"),
|
||||
("beit", "BeitModel"),
|
||||
("bit", "BitModel"),
|
||||
("conditional_detr", "ConditionalDetrModel"),
|
||||
|
@ -45,6 +45,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("aimv2", "CLIPProcessor"),
|
||||
("align", "AlignProcessor"),
|
||||
("altclip", "AltCLIPProcessor"),
|
||||
("aria", "AriaProcessor"),
|
||||
|
@ -56,6 +56,13 @@ logger = logging.get_logger(__name__)
|
||||
# Explicit rather than inferred generics to significantly improves completion suggestion performance for language servers.
|
||||
TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
[
|
||||
(
|
||||
"aimv2",
|
||||
(
|
||||
"CLIPTokenizer",
|
||||
"CLIPTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"albert",
|
||||
(
|
||||
|
@ -475,7 +475,6 @@ class SiglipPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_no_split_modules = [
|
||||
"SiglipTextEmbeddings",
|
||||
"SiglipEncoderLayer",
|
||||
"SiglipVisionEmbeddings",
|
||||
"SiglipEncoderLayer",
|
||||
"SiglipMultiheadAttentionPoolingHead",
|
||||
|
@ -708,7 +708,6 @@ class Siglip2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
_no_split_modules = [
|
||||
"Siglip2TextEmbeddings",
|
||||
"Siglip2EncoderLayer",
|
||||
"Siglip2VisionEmbeddings",
|
||||
"Siglip2EncoderLayer",
|
||||
"Siglip2MultiheadAttentionPoolingHead",
|
||||
|
0
tests/models/aimv2/__init__.py
Normal file
742
tests/models/aimv2/test_modeling_aimv2.py
Normal file
@ -0,0 +1,742 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch AIMv2 model."""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from transformers import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
random_attention_mask,
|
||||
require_torch_sdpa,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
Aimv2Model,
|
||||
Aimv2TextModel,
|
||||
Aimv2VisionModel,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoImageProcessor, AutoProcessor
|
||||
|
||||
|
||||
class Aimv2VisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=False,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return Aimv2VisionConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = Aimv2VisionModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class Aimv2ModelTesterMixin(ModelTesterMixin):
|
||||
"""
|
||||
Subclass of ModelTesterMixin with methods specific to testing Aimv2 models.
|
||||
The SDPA equivalence test is overridden here because Aimv2 models may have text/vision/text+vision inputs,
|
||||
different output logits, and are not supposed to be used or tested with padding_side="left".
|
||||
"""
|
||||
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# Load the model with SDPA
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
# Load model with eager attention
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
|
||||
if hasattr(model_sdpa, "vision_model"):
|
||||
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
|
||||
|
||||
if hasattr(model_sdpa, "text_model"):
|
||||
self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
|
||||
@require_torch
|
||||
class Aimv2VisionModelTest(Aimv2ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as Aimv2 does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (Aimv2VisionModel,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Aimv2VisionModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=Aimv2VisionConfig, has_text_modality=False, hidden_size=37
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="Aimv2 does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_model_get_set_embeddings(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
|
||||
class Aimv2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=512,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
if input_mask is not None:
|
||||
batch_size, seq_length = input_mask.shape
|
||||
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
input_mask[batch_idx, :start_index] = 1
|
||||
input_mask[batch_idx, start_index:] = 0
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask
|
||||
|
||||
def get_config(self):
|
||||
return Aimv2TextConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask):
|
||||
model = Aimv2TextModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, input_mask = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class Aimv2TextModelTest(Aimv2ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Aimv2TextModel,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Aimv2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Aimv2TextConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Aimv2 does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
|
||||
class Aimv2ModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=False):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = Aimv2TextModelTester(parent, **text_kwargs)
        self.vision_model_tester = Aimv2VisionModelTester(parent, **vision_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        return Aimv2Config.from_text_vision_configs(
            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        model = Aimv2Model(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": False,
        }
        return config, inputs_dict


@require_torch
class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    additional_model_inputs = ["pixel_values"]
    all_model_classes = (Aimv2Model,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": Aimv2Model, "image-feature-extraction": Aimv2VisionModel}
        if is_torch_available()
        else {}
    )
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    _is_composite = True

    def setUp(self):
        self.model_tester = Aimv2ModelTester(self)
        common_properties = ["projection_dim", "logit_scale_init_value"]
        self.config_tester = ConfigTester(
            self, config_class=Aimv2Config, has_text_modality=False, common_properties=common_properties
        )

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_config(self):
        self.config_tester.run_common_tests()

@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Aimv2Model does not have input/output embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
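    # In this model family the temperature parameter `logit_scale` is expected to start
    # at log(1 / 0.07) ≈ 2.659 (the CLIP-style initialization), which is why the generic
    # zero-init check from the common tester is overridden below.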
    # Override as the `logit_scale` parameter initialization is different for Aimv2
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

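    # Round-trip check: saving the composite Aimv2Config should produce a config file
    # from which the nested Aimv2VisionConfig / Aimv2TextConfig can each be reloaded
    # on their own.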
    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save Aimv2Config and check if we can load Aimv2VisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = Aimv2VisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save Aimv2Config and check if we can load Aimv2TextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = Aimv2TextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

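    # The two Flash Attention 2 tests below save the randomly initialized model, reload
    # it once with attn_implementation="flash_attention_2" and once with the default
    # (or explicitly eager) attention, and compare the resulting logits in bfloat16
    # using loose tolerances (atol=rtol=4e-2).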
    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence(self):
        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                model.to(torch_device)

                dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
                dummy_input_ids = inputs_dict["input_ids"]

                outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
                outputs_fa = model_fa(
                    pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
                )

                self.assertTrue(
                    torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
                    f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
                )
                self.assertTrue(
                    torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
                    f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
                )

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
                )
                model.to(torch_device)

                dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
                dummy_input_ids = inputs_dict["input_ids"]
                dummy_pixel_mask = inputs_dict["attention_mask"]

                # right padding
                dummy_pixel_mask[:] = 1
                dummy_pixel_mask[:, -1:] = 0

                outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
                outputs_fa = model_fa(
                    pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
                )

                logits_per_image_eager = outputs.logits_per_image[:, :-1]
                logits_per_text_eager = outputs.logits_per_text[:, :-1]

                logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1]
                logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1]

                self.assertTrue(
                    torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2),
                    f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}",
                )
                self.assertTrue(
                    torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
                    f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
                )

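    # This override only adds the flaky decorator and then dispatches back to the
    # parent ModelTesterMixin implementation via the current test method name.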
    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @require_torch_sdpa
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, *args):
        # Add only the flaky decorator here and call the parent test method
        return getattr(ModelTesterMixin, self._testMethodName)(self)

    # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest._create_and_check_torchscript with CLIP->Aimv2
    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to False")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # Aimv2 needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

                model.to(torch_device)
                model.eval()

                loaded_model.to(torch_device)
                loaded_model.eval()

                model_state_dict = model.state_dict()
                loaded_model_state_dict = loaded_model.state_dict()

                non_persistent_buffers = {}
                for key in loaded_model_state_dict.keys():
                    if key not in model_state_dict.keys():
                        non_persistent_buffers[key] = loaded_model_state_dict[key]

                loaded_model_state_dict = {
                    key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
                }

                self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

                model_buffers = list(model.buffers())
                for non_persistent_buffer in non_persistent_buffers.values():
                    found_buffer = False
                    for i, model_buffer in enumerate(model_buffers):
                        if torch.equal(non_persistent_buffer, model_buffer):
                            found_buffer = True
                            break

                    self.assertTrue(found_buffer)
                    model_buffers.pop(i)

                models_equal = True
                for layer_name, p1 in model_state_dict.items():
                    p2 = loaded_model_state_dict[layer_name]
                    if p1.data.ne(p2.data).sum() > 0:
                        models_equal = False

                self.assertTrue(models_equal)


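# The integration tests below follow the documented usage pattern (AutoProcessor /
# AutoImageProcessor plus the matching Aimv2 model class) against the converted
# checkpoints referenced by name; they are gated behind @slow and therefore only run
# when slow tests are enabled (typically via RUN_SLOW=1).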
@require_vision
@require_torch
class Aimv2ModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "yaswanthgali/aimv2-large-patch14-224-lit-HF"
        model = Aimv2Model.from_pretrained(model_name, device_map="auto")
        processor = AutoProcessor.from_pretrained(model_name)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
        inputs = processor(
            text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt"
        ).to(model.device)

        # Forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # Verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        # handle device
        expected_logits = torch.tensor([[33.3550, 26.4255]]).to(model.device)
        self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))


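# For the 224px checkpoint with patch size 14, the encoder sees 224 / 14 = 16 patches
# per side, i.e. 16 * 16 = 256 tokens, which matches the [1, 256, 1024] hidden-state
# shape asserted below.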
@require_vision
@require_torch
class Aimv2VisionModelIntegrationTests(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "yaswanthgali/aimv2-large-patch14-224-HF"

        model = Aimv2VisionModel.from_pretrained(model_name, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(model_name)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
        inputs = processor(image, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output = model(**inputs)

        # Verify logits shape
        self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 256, 1024]))

        # Verify logits slice
        # fmt: off
        expected_logits = torch.tensor(
            [[ 0.0510, 0.0806, -0.0990, -0.0154],
            [ 2.7850, -2.5143, -0.3320, 2.4196],
            [ 2.8179, -2.4089, -0.2770, 2.3218],
            [ 2.7641, -2.4114, -0.3684, 2.2998],
            [ 2.7972, -2.3180, -0.4490, 2.2302],
            [ 2.8584, -2.5322, -0.2302, 2.4936],
            [-2.7849, 2.4121, 1.3670, -1.5514]]).to(model.device)
        # fmt: on

        output_slice = output.last_hidden_state.squeeze(0)[0:7, 0:4]
        self.assertTrue(torch.allclose(output_slice, expected_logits, atol=1e-3))

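    # The native-resolution checkpoint keeps the image at (roughly) its original size,
    # so the number of patch tokens depends on the input — 1530 tokens for this COCO
    # image — while the hidden size stays 1024, matching the shape asserted below.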
    @slow
    def test_inference_for_native_resolution(self):
        model_name = "yaswanthgali/aimv2-large-patch14-native-HF"

        model = Aimv2VisionModel.from_pretrained(model_name, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(model_name)

        image = Image.open(
            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
        )
        inputs = processor(image, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output = model(**inputs)

        # Verify logits shape
        self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 1530, 1024]))

        # Verify logits slice
        # fmt: off
        expected_logits = torch.tensor(
            [[-1.3342, 0.3720, 0.0963, 0.4159],
            [-1.5328, 0.4677, 0.0936, 0.4321],
            [-0.3775, -0.2758, -0.0803, -0.5367],
            [-1.3877, 0.5561, -1.9064, -1.1766],
            [-0.5148, 0.0108, -0.4515, -0.6402],
            [-0.3400, -0.1711, -0.1855, -0.4219],
            [-1.2877, -0.0585, -0.1646, 0.7420]]).to(model.device)
        # fmt: on

        output_slice = output.last_hidden_state.squeeze(0)[0:7, 0:4]
        self.assertTrue(torch.allclose(output_slice, expected_logits, atol=1e-3))
@ -196,6 +196,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    # models to ignore for model xxx mapping
    "Aimv2TextModel",
    "AlignTextModel",
    "AlignVisionModel",
    "ClapTextModel",