mirror of https://github.com/huggingface/transformers.git

Minor fixes post review

parent 6277203fe0
commit 9af3764771
@@ -19,7 +19,6 @@ rendered properly in your Markdown viewer.
## Overview

The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

@@ -41,19 +40,14 @@ from transformers import AutoImageProcessor, AutoModel
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained(
    "apple/aimv2-large-patch14-native",
)
model = AutoModel.from_pretrained(
    "apple/aimv2-large-patch14-native",
    trust_remote_code=True,
)
processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-native")

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
```
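A minimal follow-on sketch of inspecting the features returned above. This is not part of the documented example; it assumes the forward pass returns an output object exposing `last_hidden_state` of shape `(batch_size, num_patches, hidden_size)`.

```python
# Hedged sketch, continuing the example above (assumes `outputs` exposes
# `last_hidden_state`; mean pooling is purely illustrative).
patch_features = outputs.last_hidden_state   # (batch_size, num_patches, hidden_size)
pooled = patch_features.mean(dim=1)          # (batch_size, hidden_size)
print(patch_features.shape, pooled.shape)
```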
Here is an example of checkpoint performing zero shot classification:
Here is an example of a checkpoint performing zero-shot classification:

```python
import requests
@@ -64,13 +58,8 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]

processor = AutoProcessor.from_pretrained(
    "apple/aimv2-large-patch14-224-lit",
)
model = AutoModel.from_pretrained(
    "apple/aimv2-large-patch14-224-lit",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")

inputs = processor(
    images=image,
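# Hedged sketch, not from the documentation: an assumed continuation showing how
# a CLIP-style zero-shot pipeline is typically completed. The processor arguments
# and post-processing below are illustrative assumptions; they rely on the model
# exposing `logits_per_image` as defined by `AIMv2Output` later in this commit.
import torch

inputs = processor(images=image, text=text, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

probs = outputs.logits_per_image.softmax(dim=-1)  # (num_images, num_texts)
print(text[probs.argmax(dim=-1).item()])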
@@ -68,6 +68,8 @@ class AIMv2VisionConfig(PretrainedConfig):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_head (`bool`, *optional*, defaults to `True`):
            Whether to use the attention pooling head or not.
        is_native (`bool`, *optional*, defaults to `False`):
            Whether the checkpoint was trained for native image resolution or not.
    Example:

    ```python
@@ -103,6 +105,7 @@ class AIMv2VisionConfig(PretrainedConfig):
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
        use_head: bool = True,
        is_native: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
@@ -123,6 +126,7 @@ class AIMv2VisionConfig(PretrainedConfig):
        self.qkv_bias = qkv_bias
        self.rms_norm_eps = rms_norm_eps
        self.projection_dropout = projection_dropout
        self.is_native = is_native


class AIMv2TextConfig(PretrainedConfig):
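Taken together, these configuration hunks add the `use_head` and `is_native` flags. A hedged usage sketch follows; the import path is an assumption (the export location is not shown in this diff).

```python
# Hedged sketch: exercising the two new AIMv2VisionConfig flags.
# The import path is an assumption; adjust to wherever the class is exported.
from transformers import AIMv2VisionConfig

config = AIMv2VisionConfig(use_head=False, is_native=True)
print(config.use_head, config.is_native)  # False True
```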
@@ -164,13 +164,14 @@ def write_model(
    if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
        config.use_head = False

    if hf_repo_id == "apple/aimv2-large-patch14-native":
        config.is_native = True

    original_state_dict = load_original_state_dict(hf_repo_id)

    print("Converting model...")

    state_dict = {}
    # For `apple/aimv2-large-patch14-native` we don't have position_embedding in state_dict
    strict_loading = False
    result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
    all_keys = list(original_state_dict.keys())

@@ -187,19 +188,17 @@ def write_model(
        # Check if position embeddings exist before squeezing
        if new_key.endswith("position_embedding.weight"):
            state_dict[new_key] = value.squeeze(0)
            strict_loading = True

    print(f"Loading the checkpoint in a {model_class.__name__}.")
    model = model_class(config)
    model.load_state_dict(state_dict, strict=strict_loading, assign=True)
    model.load_state_dict(state_dict, strict=True, assign=True)
    print("Checkpoint loaded successfully.")

    print("Saving the model.")
    model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    del state_dict, model

    # Safety check: reload the converted model
    gc.collect()

    print("Reloading the model to check if it's saved correctly.")
    model = model_class.from_pretrained(output_dir, device_map="auto")
    print("Model reloaded successfully.")
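For readers unfamiliar with the conversion pattern, here is a generic, hedged sketch of the regex-based key remapping that `convert_old_keys_to_new_keys(...)` refers to; the real helper's signature and behavior may differ.

```python
# Generic sketch of checkpoint key remapping (illustrative only; the real
# convert_old_keys_to_new_keys helper may behave differently).
import re
from typing import Dict, Iterable

def remap_state_dict_keys(old_keys: Iterable[str], key_mapping: Dict[str, str]) -> Dict[str, str]:
    """Map each old checkpoint key to a new model key via regex substitutions."""
    remapped = {}
    for old_key in old_keys:
        new_key = old_key
        for pattern, replacement in key_mapping.items():
            new_key = re.sub(pattern, replacement, new_key)
        remapped[old_key] = new_key
    return remapped

# Hypothetical key names, used only to show the mechanics.
print(remap_state_dict_keys(
    ["image_encoder.pos_embed.weight"],
    {r"^image_encoder\.": "vision_model.", r"pos_embed": "position_embedding"},
))  # {'image_encoder.pos_embed.weight': 'vision_model.position_embedding.weight'}
```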
@@ -52,6 +52,8 @@ logger = logging.get_logger(__name__)
class AIMv2Output(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
@@ -64,14 +66,15 @@ class AIMv2Output(ModelOutput):
            The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`AIMv2TextModel`].
        vision_model_output (`BaseModelOutput`):
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`AIMv2VisionModel`].
    """

    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

@@ -133,7 +136,8 @@ class AIMv2VisionEmbeddings(nn.Module):
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)

        num_patches = (config.image_size // config.patch_size) ** 2
        self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
        if not self.config.is_native:
            self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
            self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)

    @staticmethod
@@ -158,7 +162,7 @@ class AIMv2VisionEmbeddings(nn.Module):
        hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        hidden_states = self.rms_norm(hidden_states)

        if self.config.image_size != height or self.config.image_size != width:
        if self.config.is_native:
            pos_embed = self.build_2d_sincos_position_embedding(
                height // self.patch_size,
                width // self.patch_size,
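For orientation, the call above builds position embeddings on the fly for native-resolution inputs. Below is a minimal, hedged sketch of the standard 2D sin-cos position embedding this refers to; the model's own `build_2d_sincos_position_embedding` may differ in signature, ordering, and defaults.

```python
# Illustrative sketch of a standard 2D sin-cos position embedding
# (MoCo v3 / MAE style); not necessarily identical to the model's method.
import torch

def sincos_2d_position_embedding(height: int, width: int, dim: int, temperature: float = 10000.0) -> torch.Tensor:
    assert dim % 4 == 0, "embedding dim must be divisible by 4 for 2D sin-cos"
    grid_h, grid_w = torch.meshgrid(
        torch.arange(height, dtype=torch.float32),
        torch.arange(width, dtype=torch.float32),
        indexing="ij",
    )
    pos_dim = dim // 4
    omega = 1.0 / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    out_w = grid_w.flatten()[:, None] * omega[None, :]
    # One row per patch: [sin(h), cos(h), sin(w), cos(w)]
    return torch.cat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)

pos_embed = sincos_2d_position_embedding(14, 14, 1024)
print(pos_embed.shape)  # torch.Size([196, 1024])
```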
@@ -506,6 +510,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
        elif hasattr(module, "logit_scale"):
            if isinstance(module.logit_scale, nn.Parameter):
                module.logit_scale.data.fill_(math.log(1 / 0.07))
        elif isinstance(module, AIMv2AttentionPoolingHead):
            module.cls_token.data.normal_(mean=0.0, std=std)


class AIMv2VisionModel(AIMv2PreTrainedModel):
@@ -516,6 +522,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
        self.config = config
        self.embeddings = AIMv2VisionEmbeddings(config)
        self.encoder = AIMv2Encoder(config)
        # The only change from SiglipVisionTransformer is layernorm -> rms_norm.
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)

        self.use_head = config.use_head
@@ -722,7 +729,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)

        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
        self.max_logit_scale = math.log(config.max_logit_scale)
        self.max_log_logit_scale = math.log(config.max_logit_scale)

        self.post_init()

@@ -881,7 +888,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)

        logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
        logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
        logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
        logits_per_image = logits_per_text.t()
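The rename makes explicit that the stored bound lives in log space: clamping the log of the temperature at `log(max)` guarantees the exponentiated scale never exceeds `max`. A small hedged illustration (the `max_logit_scale` value below is assumed for demonstration, not read from the config):

```python
# Hedged illustration of the log-space clamp used above.
import math
import torch

logit_scale = torch.nn.Parameter(torch.tensor(math.log(1 / 0.07)))  # CLIP-style init, ~2.659
max_logit_scale = 100.0                                             # assumed value for illustration
max_log_logit_scale = math.log(max_logit_scale)

scale = logit_scale.clamp(0.0, max_log_logit_scale).exp()
assert scale.item() <= max_logit_scale
print(round(scale.item(), 2))  # ~14.29, i.e. exp(log(1 / 0.07))
```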
@@ -16,8 +16,7 @@
"""PyTorch implementation of AIMv2 Model"""

import math
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -29,14 +28,13 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel

from ...activations import ACT2FN
from ...utils import (
    ModelOutput,
    can_return_tuple,
    logging,
)
from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward
from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from ..siglip.modeling_siglip import SiglipEncoder
from ..siglip.modeling_siglip import SiglipEncoder, SiglipOutput


logger = logging.get_logger(__name__)
@@ -84,6 +82,8 @@ class AIMv2VisionConfig(SiglipVisionConfig):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_head (`bool`, *optional*, defaults to `True`):
            Whether to use the attention pooling head or not.
        is_native (`bool`, *optional*, defaults to `False`):
            Whether the checkpoint was trained for native image resolution or not.
    Example:

    ```python
@@ -116,6 +116,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
        use_head: bool = True,
        is_native: bool = False,
        **kwargs,
    ):
        super().__init__(
@@ -138,6 +139,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
        self.qkv_bias = qkv_bias
        self.rms_norm_eps = rms_norm_eps
        self.projection_dropout = projection_dropout
        self.is_native = is_native

        del self.layer_norm_eps
@@ -296,38 +298,8 @@ class AIMv2Config(SiglipConfig):
    pass


@dataclass
class AIMv2Output(ModelOutput):
    """
    Args:
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`AIMv2TextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`AIMv2TextModel`].
        vision_model_output (`BaseModelOutput`):
            The output of the [`AIMv2VisionModel`].
    """

    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
class AIMv2Output(SiglipOutput):
    pass


class AIMv2RMSNorm(LlamaRMSNorm):
@@ -364,7 +336,8 @@ class AIMv2VisionEmbeddings(nn.Module):
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)

        num_patches = (config.image_size // config.patch_size) ** 2
        self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
        if not self.config.is_native:
            self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
            self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)

    @staticmethod
@@ -389,7 +362,7 @@ class AIMv2VisionEmbeddings(nn.Module):
        hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        hidden_states = self.rms_norm(hidden_states)

        if self.config.image_size != height or self.config.image_size != width:
        if self.config.is_native:
            pos_embed = self.build_2d_sincos_position_embedding(
                height // self.patch_size,
                width // self.patch_size,
@@ -580,6 +553,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
        elif hasattr(module, "logit_scale"):
            if isinstance(module.logit_scale, nn.Parameter):
                module.logit_scale.data.fill_(math.log(1 / 0.07))
        elif isinstance(module, AIMv2AttentionPoolingHead):
            module.cls_token.data.normal_(mean=0.0, std=std)


class AIMv2VisionModel(AIMv2PreTrainedModel):
@@ -590,6 +565,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
        self.config = config
        self.embeddings = AIMv2VisionEmbeddings(config)
        self.encoder = AIMv2Encoder(config)
        # The only change from SiglipVisionTransformer is layernorm -> rms_norm.
        self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)

        self.use_head = config.use_head
@@ -716,7 +692,7 @@ class AIMv2Model(CLIPModel, nn.Module):
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)

        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
        self.max_logit_scale = math.log(config.max_logit_scale)
        self.max_log_logit_scale = math.log(config.max_logit_scale)

        self.post_init()

@@ -779,7 +755,7 @@ class AIMv2Model(CLIPModel, nn.Module):
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)

        logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
        logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
        logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
        logits_per_image = logits_per_text.t()
@@ -56,6 +56,8 @@ if TYPE_CHECKING:
else:
    IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
        [
            ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
            ("aria", ("AriaImageProcessor",)),
            ("beit", ("BeitImageProcessor",)),

@@ -45,6 +45,8 @@ logger = logging.get_logger(__name__)

PROCESSOR_MAPPING_NAMES = OrderedDict(
    [
        ("aimv2", "CLIPProcessor"),
        ("aimv2_vision_model", "CLIPProcessor"),
        ("align", "AlignProcessor"),
        ("altclip", "AltCLIPProcessor"),
        ("aria", "AriaProcessor"),
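With these two mappings, the Auto classes resolve AIMv2 checkpoints to CLIP processing classes. A hedged sketch (actual resolution also depends on each checkpoint's hub configuration, which is not part of this diff):

```python
# Hedged sketch: the auto mappings above route "aimv2" models to CLIP
# processing classes.
from transformers import AutoImageProcessor, AutoProcessor

image_processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
print(type(image_processor).__name__)  # e.g. CLIPImageProcessorFast
print(type(processor).__name__)        # e.g. CLIPProcessor
```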