mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)

Minor fixes post review

This commit is contained in:
parent 6277203fe0
commit 9af3764771
@@ -19,7 +19,6 @@ rendered properly in your Markdown viewer.
 
 ## Overview
 
 The AIMv2 model was proposed in [Multimodal Autoregressive Pre-training of Large Vision Encoders](https://arxiv.org/abs/2411.14402) by Enrico Fini, Mustafa Shukor, Xiujun Li, Philipp Dufter, Michal Klein, David Haldimann, Sai Aitharaju, Victor Guilherme Turrisi da Costa, Louis Béthune, Zhe Gan, Alexander T Toshev, Marcin Eichner, Moin Nabi, Yinfei Yang, Joshua M. Susskind, Alaaeldin El-Nouby.
-<INSERT SHORT SUMMARY HERE>
 
 The abstract from the paper is the following:
 
@@ -41,19 +40,14 @@ from transformers import AutoImageProcessor, AutoModel
 url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 image = Image.open(requests.get(url, stream=True).raw)
 
-processor = AutoImageProcessor.from_pretrained(
-    "apple/aimv2-large-patch14-native",
-)
-model = AutoModel.from_pretrained(
-    "apple/aimv2-large-patch14-native",
-    trust_remote_code=True,
-)
+processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
+model = AutoModel.from_pretrained("apple/aimv2-large-patch14-native")
 
 inputs = processor(images=image, return_tensors="pt")
 outputs = model(**inputs)
 ```
 
-Here is an example of checkpoint performing zero shot classification:
+Here is an example of a checkpoint performing zero-shot classification:
 
 ```python
 import requests
@@ -64,13 +58,8 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 image = Image.open(requests.get(url, stream=True).raw)
 text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]
 
-processor = AutoProcessor.from_pretrained(
-    "apple/aimv2-large-patch14-224-lit",
-)
-model = AutoModel.from_pretrained(
-    "apple/aimv2-large-patch14-224-lit",
-    trust_remote_code=True,
-)
+processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
+model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")
 
 inputs = processor(
     images=image,
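Note: the hunk above cuts the doc snippet off mid-call. For context, a minimal end-to-end sketch of the zero-shot flow the updated doc is building toward could look like the following; the `padding` argument and the softmax step are assumptions based on the usual CLIP-style API in transformers, not lines taken from this diff.

```python
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]

# Checkpoint name taken from the diff; no trust_remote_code is needed once the model lives in transformers.
processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224-lit")

inputs = processor(images=image, text=text, padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image has shape (image_batch_size, text_batch_size); softmax gives per-prompt probabilities.
probs = outputs.logits_per_image.softmax(dim=-1)
print(dict(zip(text, probs[0].tolist())))
```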
@@ -68,6 +68,8 @@ class AIMv2VisionConfig(PretrainedConfig):
             The standard deviation of the for initializing all weight matrices.
         use_head (`str`, *optional*, defaults to `True`):
             Whether to use Attention Pooling Head or Not.
+        is_native (`str`, *optional*, defaults to `False`):
+            Whether to use ckpt trained for image native resolution or not.
     Example:
 
     ```python
@@ -103,6 +105,7 @@ class AIMv2VisionConfig(PretrainedConfig):
         hidden_act: str = "silu",
         initializer_range: float = 0.02,
         use_head: bool = True,
+        is_native: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -123,6 +126,7 @@ class AIMv2VisionConfig(PretrainedConfig):
         self.qkv_bias = qkv_bias
         self.rms_norm_eps = rms_norm_eps
         self.projection_dropout = projection_dropout
+        self.is_native = is_native
 
 
 class AIMv2TextConfig(PretrainedConfig):
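Taken together, the three config hunks above add a single boolean switch. A minimal sketch of how the new flag would be set, assuming the `AIMv2VisionConfig` class from this branch is exported at the top level of transformers:

```python
from transformers import AIMv2VisionConfig  # assumption: class is exported by this branch

# Default: fixed-resolution checkpoint with learned position embeddings and the attention pooling head.
config = AIMv2VisionConfig()
print(config.use_head, config.is_native)  # True False

# Native-resolution checkpoint: no learned position embeddings are created;
# 2D sin-cos embeddings are built on the fly instead (see the embedding hunks below).
native_config = AIMv2VisionConfig(is_native=True, use_head=False)
```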
@@ -164,13 +164,14 @@ def write_model(
     if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
         config.use_head = False
 
+    if hf_repo_id == "apple/aimv2-large-patch14-native":
+        config.is_native = True
+
     original_state_dict = load_original_state_dict(hf_repo_id)
 
     print("Converting model...")
 
     state_dict = {}
-    # For `apple/aimv2-large-patch14-native` we don't have position_embedding in state_dict
-    strict_loading = False
     result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
     all_keys = list(original_state_dict.keys())
 
@@ -187,19 +188,17 @@ def write_model(
         # Check if position embeddings exist before squeezing
         if new_key.endswith("position_embedding.weight"):
             state_dict[new_key] = value.squeeze(0)
-            strict_loading = True
 
     print(f"Loading the checkpoint in a {model_class.__name__}.")
     model = model_class(config)
-    model.load_state_dict(state_dict, strict=strict_loading, assign=True)
+    model.load_state_dict(state_dict, strict=True, assign=True)
     print("Checkpoint loaded successfully.")
 
     print("Saving the model.")
     model.save_pretrained(output_dir, safe_serialization=safe_serialization)
     del state_dict, model
 
-    # Safety check: reload the converted model
     gc.collect()
 
     print("Reloading the model to check if it's saved correctly.")
     model = model_class.from_pretrained(output_dir, device_map="auto")
     print("Model reloaded successfully.")
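Strict loading becomes safe here because, with `config.is_native` set, the native-resolution model no longer registers a `position_embedding` module at all, so there is no missing key left to tolerate. A generic sketch of the rename-then-load pattern such conversion scripts follow; the regex mapping below is a hypothetical stand-in, not the real `key_mapping` from the script:

```python
import re
import torch

# Hypothetical mapping in the spirit of convert_old_keys_to_new_keys: old key pattern -> new key prefix.
KEY_MAPPING = {
    r"^preprocessor\.patchifier\.proj": "embeddings.patch_embed",
    r"^trunk\.post_trunk_norm": "rms_norm",
}

def rename_key(old_key: str) -> str:
    # Apply every pattern; keys that match nothing pass through unchanged.
    for pattern, replacement in KEY_MAPPING.items():
        old_key = re.sub(pattern, replacement, old_key)
    return old_key

def convert_state_dict(original: dict) -> dict:
    new_state_dict = {}
    for old_key, value in original.items():
        new_key = rename_key(old_key)
        # Position embeddings are stored with an extra leading dim in the original checkpoints.
        if new_key.endswith("position_embedding.weight"):
            value = value.squeeze(0)
        new_state_dict[new_key] = value
    return new_state_dict

# With the architecture matching the checkpoint exactly, strict loading surfaces any key mismatch:
# model.load_state_dict(convert_state_dict(original_state_dict), strict=True, assign=True)
```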
@@ -52,6 +52,8 @@ logger = logging.get_logger(__name__)
 class AIMv2Output(ModelOutput):
     """
     Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Contrastive loss for image-text similarity.
         logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
             The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
             similarity scores.
@@ -64,14 +66,15 @@ class AIMv2Output(ModelOutput):
             The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
         text_model_output (`BaseModelOutputWithPooling`):
             The output of the [`AIMv2TextModel`].
-        vision_model_output (`BaseModelOutput`):
+        vision_model_output (`BaseModelOutputWithPooling`):
             The output of the [`AIMv2VisionModel`].
     """
 
-    logits_per_image: torch.FloatTensor = None
-    logits_per_text: torch.FloatTensor = None
-    text_embeds: torch.FloatTensor = None
-    image_embeds: torch.FloatTensor = None
+    loss: Optional[torch.FloatTensor] = None
+    logits_per_image: Optional[torch.FloatTensor] = None
+    logits_per_text: Optional[torch.FloatTensor] = None
+    text_embeds: Optional[torch.FloatTensor] = None
+    image_embeds: Optional[torch.FloatTensor] = None
     text_model_output: BaseModelOutputWithPooling = None
     vision_model_output: BaseModelOutputWithPooling = None
 
@@ -133,6 +136,7 @@ class AIMv2VisionEmbeddings(nn.Module):
         self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
 
         num_patches = (config.image_size // config.patch_size) ** 2
+        if not self.config.is_native:
             self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
             self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
 
@@ -158,7 +162,7 @@ class AIMv2VisionEmbeddings(nn.Module):
         hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
         hidden_states = self.rms_norm(hidden_states)
 
-        if self.config.image_size != height or self.config.image_size != width:
+        if self.config.is_native:
             pos_embed = self.build_2d_sincos_position_embedding(
                 height // self.patch_size,
                 width // self.patch_size,
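For readers unfamiliar with the branch taken when `is_native` is set: the model skips learned position embeddings and builds 2D sin-cos embeddings for whatever patch grid the input image produces. Below is a self-contained sketch of the standard ViT/MoCo-v3 style recipe such a `build_2d_sincos_position_embedding` helper follows; the grid size, `embed_dim`, and `temperature` values are illustrative, and this is not the exact transformers implementation.

```python
import torch

def build_2d_sincos_position_embedding(height: int, width: int, embed_dim: int = 1024, temperature: float = 10000.0):
    """2D sin-cos position embedding: a quarter of the channels each for sin/cos of x and y."""
    grid_w = torch.arange(width, dtype=torch.float32)
    grid_h = torch.arange(height, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")

    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1.0 / (temperature**omega)

    out_w = grid_w.flatten()[..., None] @ omega[None, :]  # (height * width, pos_dim)
    out_h = grid_h.flatten()[..., None] @ omega[None, :]

    # (1, height * width, embed_dim), ready to be added to the patch tokens.
    return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]

pos_embed = build_2d_sincos_position_embedding(height=16, width=24)
print(pos_embed.shape)  # torch.Size([1, 384, 1024])
```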
@@ -506,6 +510,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
         elif hasattr(module, "logit_scale"):
             if isinstance(module.logit_scale, nn.Parameter):
                 module.logit_scale.data.fill_(math.log(1 / 0.07))
+        elif isinstance(module, AIMv2AttentionPoolingHead):
+            module.cls_token.data.normal_(mean=0.0, std=std)
 
 
 class AIMv2VisionModel(AIMv2PreTrainedModel):
@@ -516,6 +522,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
         self.config = config
         self.embeddings = AIMv2VisionEmbeddings(config)
         self.encoder = AIMv2Encoder(config)
+        # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
         self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
 
         self.use_head = config.use_head
@@ -722,7 +729,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
         self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
 
         self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
-        self.max_logit_scale = math.log(config.max_logit_scale)
+        self.max_log_logit_scale = math.log(config.max_logit_scale)
 
         self.post_init()
 
@@ -881,7 +888,7 @@ class AIMv2Model(AIMv2PreTrainedModel):
         image_embeds = image_embeds / _get_vector_norm(image_embeds)
         text_embeds = text_embeds / _get_vector_norm(text_embeds)
 
-        logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
+        logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
         logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
         logits_per_image = logits_per_text.t()
 
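The rename makes explicit that the clamp bound lives in log space: `logit_scale` is a learnable log-temperature, clamped to `[0, log(max_logit_scale)]` and only then exponentiated, so the effective multiplier stays within `[1, max_logit_scale]`. A small numeric sketch, assuming a `max_logit_scale` of 100 and the common CLIP-style init of `log(1 / 0.07)` (neither value is shown in this diff):

```python
import math
import torch
from torch import nn
import torch.nn.functional as F

logit_scale_init_value = math.log(1 / 0.07)   # ~2.659, assumed CLIP-style init
max_logit_scale = 100.0                       # assumed config value for illustration

logit_scale = nn.Parameter(torch.tensor(logit_scale_init_value))
max_log_logit_scale = math.log(max_logit_scale)  # ~4.605, the clamp bound in log space

# Clamp in log space, then exponentiate.
scale = logit_scale.clamp(0.0, max_log_logit_scale).exp()
print(float(scale))  # ~14.29, i.e. 1 / 0.07

image_embeds = F.normalize(torch.randn(2, 8), dim=-1)
text_embeds = F.normalize(torch.randn(3, 8), dim=-1)
logits_per_text = (scale * text_embeds) @ image_embeds.t()  # (text_batch, image_batch)
logits_per_image = logits_per_text.t()                      # (image_batch, text_batch)
```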
@@ -16,8 +16,7 @@
 """Pytorch implementation of AIMv2 Model"""
 
 import math
-from dataclasses import dataclass
-from typing import Any, Callable, Optional, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -29,14 +28,13 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 
 from ...activations import ACT2FN
 from ...utils import (
-    ModelOutput,
     can_return_tuple,
     logging,
 )
 from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
 from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward
 from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
-from ..siglip.modeling_siglip import SiglipEncoder
+from ..siglip.modeling_siglip import SiglipEncoder, SiglipOutput
 
 
 logger = logging.get_logger(__name__)
@@ -84,6 +82,8 @@ class AIMv2VisionConfig(SiglipVisionConfig):
             The standard deviation of the for initializing all weight matrices.
         use_head (`str`, *optional*, defaults to `True`):
             Whether to use Attention Pooling Head or Not.
+        is_native (`str`, *optional*, defaults to `False`):
+            Whether to use ckpt trained for image native resolution or not.
     Example:
 
     ```python
@@ -116,6 +116,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
         hidden_act: str = "silu",
         initializer_range: float = 0.02,
         use_head: bool = True,
+        is_native: bool = False,
         **kwargs,
     ):
         super().__init__(
@@ -138,6 +139,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
         self.qkv_bias = qkv_bias
         self.rms_norm_eps = rms_norm_eps
         self.projection_dropout = projection_dropout
+        self.is_native = is_native
 
         del self.layer_norm_eps
 
@@ -296,38 +298,8 @@ class AIMv2Config(SiglipConfig):
     pass
 
 
-@dataclass
-class AIMv2Output(ModelOutput):
-    """
-    Args:
-        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
-            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
-            similarity scores.
-        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
-            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
-            similarity scores.
-        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
-            The text embeddings obtained by applying the projection layer to the pooled output of [`AIMv2TextModel`].
-        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
-            The image embeddings obtained by applying the projection layer to the pooled output of [`AIMv2VisionModel`].
-        text_model_output (`BaseModelOutputWithPooling`):
-            The output of the [`AIMv2TextModel`].
-        vision_model_output (`BaseModelOutput`):
-            The output of the [`AIMv2VisionModel`].
-    """
-
-    logits_per_image: torch.FloatTensor = None
-    logits_per_text: torch.FloatTensor = None
-    text_embeds: torch.FloatTensor = None
-    image_embeds: torch.FloatTensor = None
-    text_model_output: BaseModelOutputWithPooling = None
-    vision_model_output: BaseModelOutputWithPooling = None
-
-    def to_tuple(self) -> Tuple[Any]:
-        return tuple(
-            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
-            for k in self.keys()
-        )
+class AIMv2Output(SiglipOutput):
+    pass
 
 
 class AIMv2RMSNorm(LlamaRMSNorm):
@@ -364,6 +336,7 @@ class AIMv2VisionEmbeddings(nn.Module):
         self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
 
         num_patches = (config.image_size // config.patch_size) ** 2
+        if not self.config.is_native:
             self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
             self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
 
@@ -389,7 +362,7 @@ class AIMv2VisionEmbeddings(nn.Module):
         hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
         hidden_states = self.rms_norm(hidden_states)
 
-        if self.config.image_size != height or self.config.image_size != width:
+        if self.config.is_native:
             pos_embed = self.build_2d_sincos_position_embedding(
                 height // self.patch_size,
                 width // self.patch_size,
@@ -580,6 +553,8 @@ class AIMv2PreTrainedModel(PreTrainedModel):
         elif hasattr(module, "logit_scale"):
             if isinstance(module.logit_scale, nn.Parameter):
                 module.logit_scale.data.fill_(math.log(1 / 0.07))
+        elif isinstance(module, AIMv2AttentionPoolingHead):
+            module.cls_token.data.normal_(mean=0.0, std=std)
 
 
 class AIMv2VisionModel(AIMv2PreTrainedModel):
@@ -590,6 +565,7 @@ class AIMv2VisionModel(AIMv2PreTrainedModel):
         self.config = config
         self.embeddings = AIMv2VisionEmbeddings(config)
         self.encoder = AIMv2Encoder(config)
+        # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
         self.rms_norm = AIMv2RMSNorm(config.hidden_size, config.rms_norm_eps)
 
         self.use_head = config.use_head
@@ -716,7 +692,7 @@ class AIMv2Model(CLIPModel, nn.Module):
         self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
 
         self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
-        self.max_logit_scale = math.log(config.max_logit_scale)
+        self.max_log_logit_scale = math.log(config.max_logit_scale)
 
         self.post_init()
 
@@ -779,7 +755,7 @@ class AIMv2Model(CLIPModel, nn.Module):
         image_embeds = image_embeds / _get_vector_norm(image_embeds)
         text_embeds = text_embeds / _get_vector_norm(text_embeds)
 
-        logit_scale = self.logit_scale.clamp(0.0, self.max_logit_scale).exp()
+        logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
         logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
         logits_per_image = logits_per_text.t()
 
@@ -56,6 +56,8 @@ if TYPE_CHECKING:
 else:
     IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
         [
+            ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+            ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
             ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
             ("aria", ("AriaImageProcessor",)),
             ("beit", ("BeitImageProcessor",)),
@@ -45,6 +45,8 @@ logger = logging.get_logger(__name__)
 
 PROCESSOR_MAPPING_NAMES = OrderedDict(
     [
+        ("aimv2", "CLIPProcessor"),
+        ("aimv2_vision_model", "CLIPProcessor"),
         ("align", "AlignProcessor"),
        ("altclip", "AltCLIPProcessor"),
         ("aria", "AriaProcessor"),
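With these two auto-mapping entries, the `Auto*` classes resolve AIMv2 checkpoints to CLIP's processors without `trust_remote_code`. A minimal usage sketch, assuming the converted checkpoints exist on the Hub under the names used elsewhere in this diff; the printed class names are what the mapping implies, not output captured from a run:

```python
from transformers import AutoImageProcessor, AutoProcessor

# "aimv2" (vision-language checkpoints) and "aimv2_vision_model" (vision-only) both map to CLIP processors.
image_processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-native")
processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")

print(type(image_processor).__name__)  # CLIPImageProcessor or CLIPImageProcessorFast
print(type(processor).__name__)        # CLIPProcessor
```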