From 9af376477101cc9588d9cee0fded91cb1b6478f7 Mon Sep 17 00:00:00 2001 From: yaswant19 Date: Wed, 23 Apr 2025 20:41:56 +0530 Subject: [PATCH] Minor fixes post review --- docs/source/en/model_doc/aimv2.md | 21 ++----- .../models/aimv2/configuration_aimv2.py | 4 ++ .../convert_aimv2_original_pytorch_to_hf.py | 11 ++-- .../models/aimv2/modeling_aimv2.py | 25 ++++++--- .../models/aimv2/modular_aimv2.py | 56 ++++++------------- .../models/auto/image_processing_auto.py | 2 + .../models/auto/processing_auto.py | 2 + 7 files changed, 50 insertions(+), 71 deletions(-) diff --git a/docs/source/en/model_doc/aimv2.md b/docs/source/en/model_doc/aimv2.md index 51bf9b413ae..7db1c291bc4 100644 --- a/docs/source/en/model_doc/aimv2.md +++ b/docs/source/en/model_doc/aimv2.md @@ -19,7 +19,6 @@ rendered properly in your Markdown viewer. ## Overview The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby. - The abstract from the paper is the following: @@ -41,19 +40,14 @@ from transformers import AutoImageProcessor, AutoModel url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) -processor = AutoImageProcessor.from_pretrained( - "apple/aimv2-large-patch14-native", -) -model = AutoModel.from_pretrained( - "apple/aimv2-large-patch14-native", - trust_remote_code=True, -) +processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native") +model = AutoModel.from_pretrained("apple/aimv2-large-patch14-native") inputs = processor(images=image, return_tensors="pt") outputs = model(**inputs) ``` -Here is an example of checkpoint performing zero shot classification: +Here is an example of a checkpoint performing zero-shot classification: ```python import requests @@ -64,13 +58,8 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."] -processor = AutoProcessor.from_pretrained( - "apple/aimv2-large-patch14-224-lit", -) -model = AutoModel.from_pretrained( - "apple/aimv2-large-patch14-224-lit", - trust_remote_code=True, -) +processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit") +model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit") inputs = processor( images=image, diff --git a/src/transformers/models/aimv2/configuration_aimv2.py b/src/transformers/models/aimv2/configuration_aimv2.py index a3498f371e3..5810880823e 100644 --- a/src/transformers/models/aimv2/configuration_aimv2.py +++ b/src/transformers/models/aimv2/configuration_aimv2.py @@ -68,6 +68,8 @@ class AIMv2VisionConfig(PretrainedConfig): The standard deviation of the for initializing all weight matrices. use_head (`str`, *optional*, defaults to `True`): Whether to use Attention Pooling Head or Not. + is_native (`str`, *optional*, defaults to `False`): + Whether to use ckpt trained for image native resolution or not. 
Example: ```python @@ -103,6 +105,7 @@ class AIMv2VisionConfig(PretrainedConfig): hidden_act: str = "silu", initializer_range: float = 0.02, use_head: bool = True, + is_native: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -123,6 +126,7 @@ class AIMv2VisionConfig(PretrainedConfig): self.qkv_bias = qkv_bias self.rms_norm_eps = rms_norm_eps self.projection_dropout = projection_dropout + self.is_native = is_native class AIMv2TextConfig(PretrainedConfig): diff --git a/src/transformers/models/aimv2/convert_aimv2_original_pytorch_to_hf.py b/src/transformers/models/aimv2/convert_aimv2_original_pytorch_to_hf.py index ce9e2540b13..d53796bc200 100644 --- a/src/transformers/models/aimv2/convert_aimv2_original_pytorch_to_hf.py +++ b/src/transformers/models/aimv2/convert_aimv2_original_pytorch_to_hf.py @@ -164,13 +164,14 @@ def write_model( if hf_repo_id != "apple/aimv2-large-patch14-224-lit": config.use_head = False + if hf_repo_id == "apple/aimv2-large-patch14-native": + config.is_native = True + original_state_dict = load_original_state_dict(hf_repo_id) print("Converting model...") state_dict = {} - # For `apple/aimv2-large-patch14-native` we don't have position_embedding in state_dict - strict_loading = False result = convert_old_keys_to_new_keys(original_state_dict, key_mapping) all_keys = list(original_state_dict.keys()) @@ -187,19 +188,17 @@ def write_model( # Check if position embeddings exist before squeezing if new_key.endswith("position_embedding.weight"): state_dict[new_key] = value.squeeze(0) - strict_loading = True print(f"Loading the checkpoint in a {model_class.__name__}.") model = model_class(config) - model.load_state_dict(state_dict, strict=strict_loading, assign=True) + model.load_state_dict(state_dict, strict=True, assign=True) print("Checkpoint loaded successfully.") print("Saving the model.") model.save_pretrained(output_dir, safe_serialization=safe_serialization) del state_dict, model - - # Safety check: reload the converted model gc.collect() + print("Reloading the model to check if it's saved correctly.") model = model_class.from_pretrained(output_dir, device_map="auto") print("Model reloaded successfully.") diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index bab3d06ad93..7515fc98372 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -52,6 +52,8 @@ logger = logging.get_logger(__name__) class AIMv2Output(ModelOutput): """ Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text similarity scores. @@ -64,14 +66,15 @@ class AIMv2Output(ModelOutput): The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`]. text_model_output (`BaseModelOutputWithPooling`): The output of the [`AIMv2TextModel`]. - vision_model_output (`BaseModelOutput`): + vision_model_output (`BaseModelOutputWithPooling`): The output of the [`AIMv2VisionModel`]. 
""" - logits_per_image: torch.FloatTensor = None - logits_per_text: torch.FloatTensor = None - text_embeds: torch.FloatTensor = None - image_embeds: torch.FloatTensor = None + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None text_model_output: BaseModelOutputWithPooling = None vision_model_output: BaseModelOutputWithPooling = None @@ -133,7 +136,8 @@ class AIMv2VisionEmbeddings(nn.Module): self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps) num_patches = (config.image_size // config.patch_size) ** 2 - self.position_embedding = nn.Embedding(num_patches, config.hidden_size) + if not self.config.is_native: + self.position_embedding = nn.Embedding(num_patches, config.hidden_size) self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False) @staticmethod @@ -158,7 +162,7 @@ class AIMv2VisionEmbeddings(nn.Module): hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2) hidden_states = self.rms_norm(hidden_states) - if self.config.image_size != height or self.config.image_size != width: + if self.config.is_native: pos_embed = self.build_2d_sincos_position_embedding( height // self.patch_size, width // self.patch_size, @@ -506,6 +510,8 @@ class AIMv2PreTrainedModel(PreTrainedModel): elif hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): module.logit_scale.data.fill_(math.log(1 / 0.07)) + elif isinstance(module, AIMv2AttentionPoolingHead): + module.cls_token.data.normal_(mean=0.0, std=std) class AIMv2VisionModel(AIMv2PreTrainedModel): @@ -516,6 +522,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel): self.config = config self.embeddings = AIMv2VisionEmbeddings(config) self.encoder = AIMv2Encoder(config) + # The only change from SiglipVisionTransformer is, layernorm -> rms_norm. 
self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps) self.use_head = config.use_head @@ -722,7 +729,7 @@ class AIMv2Model(AIMv2PreTrainedModel): self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) - self.max_logit_scale = math.log(config.max_logit_scale) + self.max_log_logit_scale = math.log(config.max_logit_scale) self.post_init() @@ -881,7 +888,7 @@ class AIMv2Model(AIMv2PreTrainedModel): image_embeds = image_embeds / _get_vector_norm(image_embeds) text_embeds = text_embeds / _get_vector_norm(text_embeds) - logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp() + logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp() logits_per_text = (logit_scale * text_embeds) @ image_embeds.t() logits_per_image = logits_per_text.t() diff --git a/src/transformers/models/aimv2/modular_aimv2.py b/src/transformers/models/aimv2/modular_aimv2.py index f96ec8b9c59..3948454cc08 100644 --- a/src/transformers/models/aimv2/modular_aimv2.py +++ b/src/transformers/models/aimv2/modular_aimv2.py @@ -16,8 +16,7 @@ """Pytorch implementation of AIMv2 Model""" import math -from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import Callable, Optional, Tuple import torch import torch.nn.functional as F @@ -29,14 +28,13 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...activations import ACT2FN from ...utils import ( - ModelOutput, can_return_tuple, logging, ) from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig -from ..siglip.modeling_siglip import SiglipEncoder +from ..siglip.modeling_siglip import SiglipEncoder, SiglipOutput logger = logging.get_logger(__name__) @@ -84,6 +82,8 @@ class AIMv2VisionConfig(SiglipVisionConfig): The standard deviation of the for initializing all weight matrices. use_head (`str`, *optional*, defaults to `True`): Whether to use Attention Pooling Head or Not. + is_native (`str`, *optional*, defaults to `False`): + Whether to use ckpt trained for image native resolution or not. Example: ```python @@ -116,6 +116,7 @@ class AIMv2VisionConfig(SiglipVisionConfig): hidden_act: str = "silu", initializer_range: float = 0.02, use_head: bool = True, + is_native: bool = False, **kwargs, ): super().__init__( @@ -138,6 +139,7 @@ class AIMv2VisionConfig(SiglipVisionConfig): self.qkv_bias = qkv_bias self.rms_norm_eps = rms_norm_eps self.projection_dropout = projection_dropout + self.is_native = is_native del self.layer_norm_eps @@ -296,38 +298,8 @@ class AIMv2Config(SiglipConfig): pass -@dataclass -class AIMv2Output(ModelOutput): - """ - Args: - logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`AIMv2TextModel`]. 
- image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`]. - text_model_output (`BaseModelOutputWithPooling`): - The output of the [`AIMv2TextModel`]. - vision_model_output (`BaseModelOutput`): - The output of the [`AIMv2VisionModel`]. - """ - - logits_per_image: torch.FloatTensor = None - logits_per_text: torch.FloatTensor = None - text_embeds: torch.FloatTensor = None - image_embeds: torch.FloatTensor = None - text_model_output: BaseModelOutputWithPooling = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> Tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) +class AIMv2Output(SiglipOutput): + pass class AIMv2RMSNorm(LlamaRMSNorm): @@ -364,7 +336,8 @@ class AIMv2VisionEmbeddings(nn.Module): self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps) num_patches = (config.image_size // config.patch_size) ** 2 - self.position_embedding = nn.Embedding(num_patches, config.hidden_size) + if not self.config.is_native: + self.position_embedding = nn.Embedding(num_patches, config.hidden_size) self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False) @staticmethod @@ -389,7 +362,7 @@ class AIMv2VisionEmbeddings(nn.Module): hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2) hidden_states = self.rms_norm(hidden_states) - if self.config.image_size != height or self.config.image_size != width: + if self.config.is_native: pos_embed = self.build_2d_sincos_position_embedding( height // self.patch_size, width // self.patch_size, @@ -580,6 +553,8 @@ class AIMv2PreTrainedModel(PreTrainedModel): elif hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): module.logit_scale.data.fill_(math.log(1 / 0.07)) + elif isinstance(module, AIMv2AttentionPoolingHead): + module.cls_token.data.normal_(mean=0.0, std=std) class AIMv2VisionModel(AIMv2PreTrainedModel): @@ -590,6 +565,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel): self.config = config self.embeddings = AIMv2VisionEmbeddings(config) self.encoder = AIMv2Encoder(config) + # The only change from SiglipVisionTransformer is, layernorm -> rms_norm. 
self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps) self.use_head = config.use_head @@ -716,7 +692,7 @@ class AIMv2Model(CLIPModel, nn.Module): self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) - self.max_logit_scale = math.log(config.max_logit_scale) + self.max_log_logit_scale = math.log(config.max_logit_scale) self.post_init() @@ -779,7 +755,7 @@ class AIMv2Model(CLIPModel, nn.Module): image_embeds = image_embeds / _get_vector_norm(image_embeds) text_embeds = text_embeds / _get_vector_norm(text_embeds) - logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp() + logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp() logits_per_text = (logit_scale * text_embeds) @ image_embeds.t() logits_per_image = logits_per_text.t() diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 10ee95475ed..4acd5e9849f 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -56,6 +56,8 @@ if TYPE_CHECKING: else: IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( [ + ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), ("aria", ("AriaImageProcessor",)), ("beit", ("BeitImageProcessor",)), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c55a4ab2129..489a5cbdd7c 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -45,6 +45,8 @@ logger = logging.get_logger(__name__) PROCESSOR_MAPPING_NAMES = OrderedDict( [ + ("aimv2", "CLIPProcessor"), + ("aimv2_vision_model", "CLIPProcessor"), ("align", "AlignProcessor"), ("altclip", "AltCLIPProcessor"), ("aria", "AriaProcessor"),
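
A quick sanity check of the new auto mappings, for reviewers. This is a hedged sketch, not part of the patch: it assumes the `aimv2` entries added above are registered and simply reuses the checkpoint names and calls already shown in the updated `docs/source/en/model_doc/aimv2.md`; which fast/slow image-processor class comes back depends on the installed extras.

```python
from transformers import AutoImageProcessor, AutoProcessor

# With the new "aimv2" mappings, the auto classes should route to the CLIP
# processing classes registered in image_processing_auto.py / processing_auto.py.
image_processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")

print(type(image_processor).__name__)  # expected: CLIPImageProcessor or CLIPImageProcessorFast
print(type(processor).__name__)        # expected: CLIPProcessor
```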