Minor fixes post review

This commit is contained in:
yaswant19 2025-04-23 20:41:56 +05:30
parent 6277203fe0
commit 9af3764771
7 changed files with 50 additions and 71 deletions

View File

@ -19,7 +19,6 @@ rendered properly in your Markdown viewer.
## Overview
The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
@ -41,19 +40,14 @@ from transformers import AutoImageProcessor, AutoModel
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained(
"apple/aimv2-large-patch14-native",
)
model = AutoModel.from_pretrained(
"apple/aimv2-large-patch14-native",
trust_remote_code=True,
)
processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-native")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
```
Here is an example of checkpoint performing zero shot classification:
Here is an example of a checkpoint performing zero-shot classification:
```python
import requests
@ -64,13 +58,8 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]
processor = AutoProcessor.from_pretrained(
"apple/aimv2-large-patch14-224-lit",
)
model = AutoModel.from_pretrained(
"apple/aimv2-large-patch14-224-lit",
trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")
inputs = processor(
images=image,

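A hedged sketch of how such a CLIP-style zero-shot call is typically completed (not taken from this diff); `processor`, `model`, `image`, and `text` refer to the objects defined in the snippet above, and the `padding` argument and the `logits_per_image` field are assumptions carried over from the CLIP API:

```python
# Hedged sketch, not part of this diff: typical CLIP-style zero-shot scoring.
inputs = processor(images=image, text=text, return_tensors="pt", padding=True)
outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=-1)  # probability of each caption per image
print(dict(zip(text, probs[0].tolist())))
```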
View File

@ -68,6 +68,8 @@ class AIMv2VisionConfig(PretrainedConfig):
The standard deviation used for initializing all weight matrices.
use_head (`bool`, *optional*, defaults to `True`):
Whether to use the attention pooling head.
is_native (`bool`, *optional*, defaults to `False`):
Whether to use a checkpoint trained for native image resolution.
Example:
```python
@ -103,6 +105,7 @@ class AIMv2VisionConfig(PretrainedConfig):
hidden_act: str = "silu",
initializer_range: float = 0.02,
use_head: bool = True,
is_native: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@ -123,6 +126,7 @@ class AIMv2VisionConfig(PretrainedConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.is_native = is_native
class AIMv2TextConfig(PretrainedConfig):
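For readers unfamiliar with the new flag, a minimal sketch of how it would be set (assumption: the config class is importable under the module path and name used in this PR; the merged name or casing may differ):

```python
# Minimal sketch, assuming the import path and class name from this PR; not verified.
from transformers.models.aimv2.configuration_aimv2 import AIMv2VisionConfig

# Native-resolution checkpoints set is_native=True, so the embeddings module
# skips the learned position embedding and builds 2D sin-cos embeddings instead.
config = AIMv2VisionConfig(is_native=True)
print(config.is_native)  # True
```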

View File

@ -164,13 +164,14 @@ def write_model(
if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
config.use_head = False
if hf_repo_id == "apple/aimv2-large-patch14-native":
config.is_native = True
original_state_dict = load_original_state_dict(hf_repo_id)
print("Converting model...")
state_dict = {}
# For `apple/aimv2-large-patch14-native` we don't have position_embedding in state_dict
strict_loading = False
result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
all_keys = list(original_state_dict.keys())
@ -187,19 +188,17 @@ def write_model(
# Check if position embeddings exist before squeezing
if new_key.endswith("position_embedding.weight"):
state_dict[new_key] = value.squeeze(0)
strict_loading = True
print(f"Loading the checkpoint in a {model_class.__name__}.")
model = model_class(config)
model.load_state_dict(state_dict, strict=strict_loading, assign=True)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
model = model_class.from_pretrained(output_dir, device_map="auto")
print("Model reloaded successfully.")

View File

@ -52,6 +52,8 @@ logger = logging.get_logger(__name__)
class AIMv2Output(ModelOutput):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
@ -64,14 +66,15 @@ class AIMv2Output(ModelOutput):
The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
text_model_output (`BaseModelOutputWithPooling`):
The output of the [`AIMv2TextModel`].
vision_model_output (`BaseModelOutput`):
vision_model_output (`BaseModelOutputWithPooling`):
The output of the [`AIMv2VisionModel`].
"""
logits_per_image: torch.FloatTensor = None
logits_per_text: torch.FloatTensor = None
text_embeds: torch.FloatTensor = None
image_embeds: torch.FloatTensor = None
loss: Optional[torch.FloatTensor] = None
logits_per_image: Optional[torch.FloatTensor] = None
logits_per_text: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPooling = None
@ -133,7 +136,8 @@ class AIMv2VisionEmbeddings(nn.Module):
self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
num_patches = (config.image_size // config.patch_size) ** 2
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
if not self.config.is_native:
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
@staticmethod
@ -158,7 +162,7 @@ class AIMv2VisionEmbeddings(nn.Module):
hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
hidden_states = self.rms_norm(hidden_states)
if self.config.image_size != height or self.config.image_size != width:
if self.config.is_native:
pos_embed = self.build_2d_sincos_position_embedding(
height // self.patch_size,
width // self.patch_size,
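For context on the `is_native` path above, here is a sketch of how 2D sin-cos position embeddings are commonly built (MoCo v3 / MAE style); the actual `build_2d_sincos_position_embedding` helper in this PR may differ in signature and details:

```python
# Sketch of a standard 2D sin-cos position embedding; shapes and channel
# ordering are assumptions and may differ from the helper in this PR.
import torch


def sincos_2d(h: int, w: int, dim: int, temperature: float = 10000.0) -> torch.Tensor:
    """Return fixed position embeddings of shape (1, h * w, dim)."""
    assert dim % 4 == 0, "embedding dim must be divisible by 4"
    grid_y, grid_x = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32),
        indexing="ij",
    )
    omega = torch.arange(dim // 4, dtype=torch.float32) / (dim // 4)
    omega = 1.0 / temperature**omega
    out_y = grid_y.flatten()[:, None] * omega[None, :]  # (h * w, dim // 4)
    out_x = grid_x.flatten()[:, None] * omega[None, :]
    pos = torch.cat([out_x.sin(), out_x.cos(), out_y.sin(), out_y.cos()], dim=1)
    return pos.unsqueeze(0)
```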
@ -506,6 +510,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
elif hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, AIMv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=std)
class AIMv2VisionModel(AIMv2PreTrainedModel):
@ -516,6 +522,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
self.config = config
self.embeddings = AIMv2VisionEmbeddings(config)
self.encoder = AIMv2Encoder(config)
# The only change from SiglipVisionTransformer is the use of rms_norm instead of layernorm.
self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.use_head = config.use_head
@ -722,7 +729,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
self.max_logit_scale = math.log(config.max_logit_scale)
self.max_log_logit_scale = math.log(config.max_logit_scale)
self.post_init()
@ -881,7 +888,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
image_embeds = image_embeds / _get_vector_norm(image_embeds)
text_embeds = text_embeds / _get_vector_norm(text_embeds)
logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
logits_per_image = logits_per_text.t()
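The rename above reflects that both the parameter and its bound live in log space; a small worked sketch (the init and max values below are illustrative assumptions, not the config defaults):

```python
# Illustrative values only; the actual logit_scale_init_value / max_logit_scale
# defaults in this PR may differ.
import math

import torch

logit_scale = torch.tensor(math.log(1 / 0.07))  # parameter stores log(scale)
max_log_logit_scale = math.log(100.0)           # bound is also in log space
scale = logit_scale.clamp(0.0, max_log_logit_scale).exp()
print(scale)  # tensor(14.2857): exp(log(1/0.07)), capped at 100 after exponentiation
```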

View File

@ -16,8 +16,7 @@
"""Pytorch implementation of AIMv2 Model"""
import math
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple
from typing import Callable, Optional, Tuple
import torch
import torch.nn.functional as F
@ -29,14 +28,13 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...activations import ACT2FN
from ...utils import (
ModelOutput,
can_return_tuple,
logging,
)
from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward
from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from ..siglip.modeling_siglip import SiglipEncoder
from ..siglip.modeling_siglip import SiglipEncoder, SiglipOutput
logger = logging.get_logger(__name__)
@ -84,6 +82,8 @@ class AIMv2VisionConfig(SiglipVisionConfig):
The standard deviation used for initializing all weight matrices.
use_head (`bool`, *optional*, defaults to `True`):
Whether to use the attention pooling head.
is_native (`bool`, *optional*, defaults to `False`):
Whether to use a checkpoint trained for native image resolution.
Example:
```python
@ -116,6 +116,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
hidden_act: str = "silu",
initializer_range: float = 0.02,
use_head: bool = True,
is_native: bool = False,
**kwargs,
):
super().__init__(
@ -138,6 +139,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.is_native = is_native
del self.layer_norm_eps
@ -296,38 +298,8 @@ class AIMv2Config(SiglipConfig):
pass
@dataclass
class AIMv2Output(ModelOutput):
"""
Args:
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of [`AIMv2TextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
text_model_output (`BaseModelOutputWithPooling`):
The output of the [`AIMv2TextModel`].
vision_model_output (`BaseModelOutput`):
The output of the [`AIMv2VisionModel`].
"""
logits_per_image: torch.FloatTensor = None
logits_per_text: torch.FloatTensor = None
text_embeds: torch.FloatTensor = None
image_embeds: torch.FloatTensor = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class AIMv2Output(SiglipOutput):
pass
class AIMv2RMSNorm(LlamaRMSNorm):
@ -364,7 +336,8 @@ class AIMv2VisionEmbeddings(nn.Module):
self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
num_patches = (config.image_size // config.patch_size) ** 2
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
if not self.config.is_native:
self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
@staticmethod
@ -389,7 +362,7 @@ class AIMv2VisionEmbeddings(nn.Module):
hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
hidden_states = self.rms_norm(hidden_states)
if self.config.image_size != height or self.config.image_size != width:
if self.config.is_native:
pos_embed = self.build_2d_sincos_position_embedding(
height // self.patch_size,
width // self.patch_size,
@ -580,6 +553,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
elif hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, AIMv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=std)
class AIMv2VisionModel(AIMv2PreTrainedModel):
@ -590,6 +565,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
self.config = config
self.embeddings = AIMv2VisionEmbeddings(config)
self.encoder = AIMv2Encoder(config)
# The only change from SiglipVisionTransformer is the use of rms_norm instead of layernorm.
self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
self.use_head = config.use_head
@ -716,7 +692,7 @@ class AIMv2Model(CLIPModel, nn.Module):
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
self.max_logit_scale = math.log(config.max_logit_scale)
self.max_log_logit_scale = math.log(config.max_logit_scale)
self.post_init()
@ -779,7 +755,7 @@ class AIMv2Model(CLIPModel, nn.Module):
image_embeds = image_embeds / _get_vector_norm(image_embeds)
text_embeds = text_embeds / _get_vector_norm(text_embeds)
logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
logits_per_image = logits_per_text.t()

View File

@ -56,6 +56,8 @@ if TYPE_CHECKING:
else:
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
("aria", ("AriaImageProcessor",)),
("beit", ("BeitImageProcessor",)),

View File

@ -45,6 +45,8 @@ logger = logging.get_logger(__name__)
PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("aimv2", "CLIPProcessor"),
("aimv2_vision_model", "CLIPProcessor"),
("align", "AlignProcessor"),
("altclip", "AltCLIPProcessor"),
("aria", "AriaProcessor"),