Minor fixes post review

yaswant19 2025-04-23 20:41:56 +05:30
parent 6277203fe0
commit 9af3764771
7 changed files with 50 additions and 71 deletions


@ -19,7 +19,6 @@ rendered properly in your Markdown viewer.
## Overview ## Overview
The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby. The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following: The abstract from the paper is the following:
@ -41,19 +40,14 @@ from transformers import AutoImageProcessor, AutoModel
url = "http://images.cocodataset.org/val2017/000000039769.jpg" url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw) image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained( processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
"apple/aimv2-large-patch14-native", model = AutoModel.from_pretrained("apple/aimv2-large-patch14-native")
)
model = AutoModel.from_pretrained(
"apple/aimv2-large-patch14-native",
trust_remote_code=True,
)
inputs = processor(images=image, return_tensors="pt") inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs) outputs = model(**inputs)
``` ```
Here is an example of checkpoint performing zero shot classification: Here is an example of a checkpoint performing zero-shot classification:
```python ```python
import requests import requests
@ -64,13 +58,8 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw) image = Image.open(requests.get(url, stream=True).raw)
text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."] text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
"apple/aimv2-large-patch14-224-lit", model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")
)
model = AutoModel.from_pretrained(
"apple/aimv2-large-patch14-224-lit",
trust_remote_code=True,
)
inputs = processor( inputs = processor(
images=image, images=image,

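Reviewer note: for reference, a minimal sketch of how the simplified zero-shot snippet above is typically completed. It assumes the CLIP-style processor interface that this commit maps `aimv2` checkpoints to later on; the keyword arguments and the softmax step are the usual convention, not copied verbatim from the updated doc.

```python
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]

processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")

# CLIP-style processors accept images and text together and pad the text batch.
inputs = processor(images=image, text=text, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image has shape (image_batch_size, text_batch_size).
probs = outputs.logits_per_image.softmax(dim=-1)
print(probs)
```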

@@ -68,6 +68,8 @@ class AIMv2VisionConfig(PretrainedConfig):
            The standard deviation of the for initializing all weight matrices.
        use_head (`str`, *optional*, defaults to `True`):
            Whether to use Attention Pooling Head or Not.
+        is_native (`str`, *optional*, defaults to `False`):
+            Whether to use ckpt trained for image native resolution or not.
    Example:
    ```python
@@ -103,6 +105,7 @@ class AIMv2VisionConfig(PretrainedConfig):
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
        use_head: bool = True,
+        is_native: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
@@ -123,6 +126,7 @@ class AIMv2VisionConfig(PretrainedConfig):
        self.qkv_bias = qkv_bias
        self.rms_norm_eps = rms_norm_eps
        self.projection_dropout = projection_dropout
+        self.is_native = is_native
class AIMv2TextConfig(PretrainedConfig):

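Reviewer note: a short sketch of how the new flag surfaces to users, assuming the class is importable from `transformers` under the name used in this diff (`AIMv2VisionConfig`); only `is_native` is exercised here, everything else is left at its defaults.

```python
from transformers import AIMv2VisionConfig

# Default configuration keeps the learned position embeddings.
config = AIMv2VisionConfig()
print(config.is_native)  # False

# Native-resolution checkpoints skip the learned position embeddings and
# rely on 2D sin-cos embeddings built at forward time.
native_config = AIMv2VisionConfig(is_native=True)
print(native_config.is_native)  # True
```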

@@ -164,13 +164,14 @@ def write_model(
    if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
        config.use_head = False
+    if hf_repo_id == "apple/aimv2-large-patch14-native":
+        config.is_native = True
    original_state_dict = load_original_state_dict(hf_repo_id)
    print("Converting model...")
    state_dict = {}
-    # For `apple/aimv2-large-patch14-native` we don't have position_embedding in state_dict
-    strict_loading = False
    result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
    all_keys = list(original_state_dict.keys())
@@ -187,19 +188,17 @@ def write_model(
        # Check if position embeddings exist before squeezing
        if new_key.endswith("position_embedding.weight"):
            state_dict[new_key] = value.squeeze(0)
-            strict_loading = True
    print(f"Loading the checkpoint in a {model_class.__name__}.")
    model = model_class(config)
-    model.load_state_dict(state_dict, strict=strict_loading, assign=True)
+    model.load_state_dict(state_dict, strict=True, assign=True)
    print("Checkpoint loaded successfully.")
    print("Saving the model.")
    model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    del state_dict, model
-    # Safety check: reload the converted model
    gc.collect()
    print("Reloading the model to check if it's saved correctly.")
    model = model_class.from_pretrained(output_dir, device_map="auto")
    print("Model reloaded successfully.")

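Reviewer note: the reason `strict=True` is now safe for every checkpoint is that when `config.is_native` is set, the vision embeddings module never registers a `position_embedding` parameter, so the native checkpoint has no key to miss. A toy sketch of that pattern in plain PyTorch (the `ToyEmbeddings` class below is hypothetical, not the actual AIMv2 module):

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    """Registers a learned position embedding only when it will be used."""

    def __init__(self, is_native: bool, num_patches: int = 4, hidden_size: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        if not is_native:
            self.position_embedding = nn.Embedding(num_patches, hidden_size)


# A "native" checkpoint has no position_embedding key ...
native_state = ToyEmbeddings(is_native=True).state_dict()

# ... and a model built with is_native=True has no position_embedding parameter,
# so strict loading succeeds without special-casing missing keys.
model = ToyEmbeddings(is_native=True)
model.load_state_dict(native_state, strict=True)

# Loading the same state dict into a non-native model fails the strict check.
try:
    ToyEmbeddings(is_native=False).load_state_dict(native_state, strict=True)
except RuntimeError as err:
    print("strict loading fails as expected:", type(err).__name__)
```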

@@ -52,6 +52,8 @@ logger = logging.get_logger(__name__)
class AIMv2Output(ModelOutput):
    """
    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
@@ -64,14 +66,15 @@ class AIMv2Output(ModelOutput):
            The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`AIMv2TextModel`].
-        vision_model_output (`BaseModelOutput`):
+        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`AIMv2VisionModel`].
    """
-    logits_per_image: torch.FloatTensor = None
-    logits_per_text: torch.FloatTensor = None
-    text_embeds: torch.FloatTensor = None
-    image_embeds: torch.FloatTensor = None
+    loss: Optional[torch.FloatTensor] = None
+    logits_per_image: Optional[torch.FloatTensor] = None
+    logits_per_text: Optional[torch.FloatTensor] = None
+    text_embeds: Optional[torch.FloatTensor] = None
+    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None
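Reviewer note: for context on the new `loss` field, CLIP-style models in transformers conventionally compute the contrastive loss as a symmetric cross-entropy over the image-text similarity matrix. A minimal sketch of that convention (the function name below is mine; the exact AIMv2 implementation is inherited from the CLIP code path and may differ in detail):

```python
import torch
import torch.nn.functional as F


def clip_style_contrastive_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # The i-th text matches the i-th image, so the targets are the diagonal indices.
    labels = torch.arange(logits_per_text.size(0), device=logits_per_text.device)
    text_loss = F.cross_entropy(logits_per_text, labels)
    image_loss = F.cross_entropy(logits_per_text.t(), labels)
    return (text_loss + image_loss) / 2.0


# Example with a batch of 3 image-text pairs.
logits_per_text = torch.randn(3, 3)
print(clip_style_contrastive_loss(logits_per_text))
```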
@@ -133,6 +136,7 @@ class AIMv2VisionEmbeddings(nn.Module):
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
        num_patches = (config.image_size // config.patch_size) ** 2
+        if not self.config.is_native:
            self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
            self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
@@ -158,7 +162,7 @@ class AIMv2VisionEmbeddings(nn.Module):
        hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        hidden_states = self.rms_norm(hidden_states)
-        if self.config.image_size != height or self.config.image_size != width:
+        if self.config.is_native:
            pos_embed = self.build_2d_sincos_position_embedding(
                height // self.patch_size,
                width // self.patch_size,
@@ -506,6 +510,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
        elif hasattr(module, "logit_scale"):
            if isinstance(module.logit_scale, nn.Parameter):
                module.logit_scale.data.fill_(math.log(1 / 0.07))
+        elif isinstance(module, AIMv2AttentionPoolingHead):
+            module.cls_token.data.normal_(mean=0.0, std=std)
class AIMv2VisionModel(AIMv2PreTrainedModel):
@@ -516,6 +522,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
        self.config = config
        self.embeddings = AIMv2VisionEmbeddings(config)
        self.encoder = AIMv2Encoder(config)
+        # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.use_head = config.use_head
@@ -722,7 +729,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
-        self.max_logit_scale = math.log(config.max_logit_scale)
+        self.max_log_logit_scale = math.log(config.max_logit_scale)
        self.post_init()
@@ -881,7 +888,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)
-        logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
+        logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
        logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
        logits_per_image = logits_per_text.t()

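Reviewer note: the native-resolution path replaces the learned `nn.Embedding` with 2D sin-cos position embeddings computed from the actual patch grid at forward time. A minimal sketch of the standard construction follows; `sincos_2d_position_embedding` is my own illustrative helper, and the in-repo `build_2d_sincos_position_embedding` may differ in argument names and defaults.

```python
import torch


def sincos_2d_position_embedding(height: int, width: int, embed_dim: int, temperature: float = 10000.0):
    """Return a (1, height * width, embed_dim) tensor of fixed 2D sin-cos embeddings."""
    grid_w = torch.arange(width, dtype=torch.float32)
    grid_h = torch.arange(height, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")

    # A quarter of the channels for sin/cos of each spatial axis.
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1.0 / (temperature**omega)

    out_w = grid_w.flatten()[:, None] * omega[None, :]
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    pos_embed = torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)
    return pos_embed[None, :, :]


# For a 224x224 image with 14x14 patches: a 16x16 grid of patch positions.
print(sincos_2d_position_embedding(16, 16, 1024).shape)  # torch.Size([1, 256, 1024])
```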

@@ -16,8 +16,7 @@
"""Pytorch implementation of AIMv2 Model"""
import math
-from dataclasses import dataclass
-from typing import Any, Callable, Optional, Tuple
+from typing import Callable, Optional, Tuple
import torch
import torch.nn.functional as F
@@ -29,14 +28,13 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...activations import ACT2FN
from ...utils import (
-    ModelOutput,
    can_return_tuple,
    logging,
)
from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward
from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
-from ..siglip.modeling_siglip import SiglipEncoder
+from ..siglip.modeling_siglip import SiglipEncoder, SiglipOutput
logger = logging.get_logger(__name__)
@@ -84,6 +82,8 @@ class AIMv2VisionConfig(SiglipVisionConfig):
            The standard deviation of the for initializing all weight matrices.
        use_head (`str`, *optional*, defaults to `True`):
            Whether to use Attention Pooling Head or Not.
+        is_native (`str`, *optional*, defaults to `False`):
+            Whether to use ckpt trained for image native resolution or not.
    Example:
    ```python
@@ -116,6 +116,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
        use_head: bool = True,
+        is_native: bool = False,
        **kwargs,
    ):
        super().__init__(
@@ -138,6 +139,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
        self.qkv_bias = qkv_bias
        self.rms_norm_eps = rms_norm_eps
        self.projection_dropout = projection_dropout
+        self.is_native = is_native
        del self.layer_norm_eps
@@ -296,38 +298,8 @@ class AIMv2Config(SiglipConfig):
    pass
-@dataclass
-class AIMv2Output(ModelOutput):
-    """
-    Args:
-        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
-            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
-            similarity scores.
-        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
-            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
-            similarity scores.
-        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
-            The text embeddings obtained by applying the projection layer to the pooled output of [`AIMv2TextModel`].
-        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
-            The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
-        text_model_output (`BaseModelOutputWithPooling`):
-            The output of the [`AIMv2TextModel`].
-        vision_model_output (`BaseModelOutput`):
-            The output of the [`AIMv2VisionModel`].
-    """
-    logits_per_image: torch.FloatTensor = None
-    logits_per_text: torch.FloatTensor = None
-    text_embeds: torch.FloatTensor = None
-    image_embeds: torch.FloatTensor = None
-    text_model_output: BaseModelOutputWithPooling = None
-    vision_model_output: BaseModelOutputWithPooling = None
-    def to_tuple(self) -> Tuple[Any]:
-        return tuple(
-            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
-            for k in self.keys()
-        )
+class AIMv2Output(SiglipOutput):
+    pass
class AIMv2RMSNorm(LlamaRMSNorm):
@@ -364,6 +336,7 @@ class AIMv2VisionEmbeddings(nn.Module):
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
        num_patches = (config.image_size // config.patch_size) ** 2
+        if not self.config.is_native:
            self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
            self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
@@ -389,7 +362,7 @@ class AIMv2VisionEmbeddings(nn.Module):
        hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        hidden_states = self.rms_norm(hidden_states)
-        if self.config.image_size != height or self.config.image_size != width:
+        if self.config.is_native:
            pos_embed = self.build_2d_sincos_position_embedding(
                height // self.patch_size,
                width // self.patch_size,
@@ -580,6 +553,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
        elif hasattr(module, "logit_scale"):
            if isinstance(module.logit_scale, nn.Parameter):
                module.logit_scale.data.fill_(math.log(1 / 0.07))
+        elif isinstance(module, AIMv2AttentionPoolingHead):
+            module.cls_token.data.normal_(mean=0.0, std=std)
class AIMv2VisionModel(AIMv2PreTrainedModel):
@@ -590,6 +565,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
        self.config = config
        self.embeddings = AIMv2VisionEmbeddings(config)
        self.encoder = AIMv2Encoder(config)
+        # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.use_head = config.use_head
@@ -716,7 +692,7 @@ class AIMv2Model(CLIPModel, nn.Module):
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
-        self.max_logit_scale = math.log(config.max_logit_scale)
+        self.max_log_logit_scale = math.log(config.max_logit_scale)
        self.post_init()
@@ -779,7 +755,7 @@ class AIMv2Model(CLIPModel, nn.Module):
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)
-        logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
+        logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
        logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
        logits_per_image = logits_per_text.t()

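Reviewer note: on the `max_logit_scale` to `max_log_logit_scale` rename, the learnable parameter stores the log of the temperature, so the cap has to be expressed in log space before exponentiating. A small numeric sketch; the concrete values below are illustrative, with `logit_scale_init_value` and `max_logit_scale` normally coming from the config.

```python
import math

import torch

logit_scale_init_value = math.log(1 / 0.07)  # CLIP-style initialisation, ~2.659
max_logit_scale = 100.0                      # cap on the temperature itself

logit_scale = torch.tensor(logit_scale_init_value)
max_log_logit_scale = math.log(max_logit_scale)  # ~4.605, the same cap in log space

# Clamp in log space, then exponentiate: the effective temperature never exceeds 100.
effective_scale = logit_scale.clamp(0.0, max_log_logit_scale).exp()
print(effective_scale)  # ~14.29, i.e. exp(log(1 / 0.07))
```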

@@ -56,6 +56,8 @@ if TYPE_CHECKING:
else:
    IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
        [
+            ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+            ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
            ("aria", ("AriaImageProcessor",)),
            ("beit", ("BeitImageProcessor",)),


@@ -45,6 +45,8 @@ logger = logging.get_logger(__name__)
PROCESSOR_MAPPING_NAMES = OrderedDict(
    [
+        ("aimv2", "CLIPProcessor"),
+        ("aimv2_vision_model", "CLIPProcessor"),
        ("align", "AlignProcessor"),
        ("altclip", "AltCLIPProcessor"),
        ("aria", "AriaProcessor"),