yaswant19 2025-03-29 12:55:59 +05:30
parent 923f76f2df
commit 3c2d124a77
5 changed files with 134 additions and 19 deletions

View File

@ -80,7 +80,6 @@ ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
r"text_projector": r"text_projection",
r"log_logit_scale": r"logit_scale",
}
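For readers unfamiliar with these conversion tables: each entry is a regular-expression pattern applied to the original checkpoint keys when they are renamed for the Hugging Face state dict. A minimal sketch of how such a table is typically consumed, using only the entries visible in this hunk (the helper name is illustrative, not necessarily the one this script defines):

import re

ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
    r"text_projector": r"text_projection",
    r"log_logit_scale": r"logit_scale",
}

def convert_old_key(old_key: str) -> str:
    # Apply every pattern in turn; keys that match nothing pass through unchanged.
    new_key = old_key
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        new_key = re.sub(pattern, replacement, new_key)
    return new_key

print(convert_old_key("text_projector.weight"))  # text_projection.weight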
@ -162,8 +161,8 @@ def write_model(
config = config_class.from_pretrained(hf_repo_id)
# Checkpoint `apple/aimv2-large-patch14-224-lit` uses AttentionPoolingHead hence set the required attr in config.
-if hf_repo_id == "apple/aimv2-large-patch14-224-lit":
-config.vision_config.use_head = True
+if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
+config.use_head = False
original_state_dict = load_original_state_dict(hf_repo_id)
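After conversion the flag can be sanity-checked directly on the converted repositories; a hedged sketch using the config classes imported by the test suite below (the expected values follow from the branch above, assuming the conversion ran with it):

from transformers import AIMv2Config, AIMv2VisionConfig

# The LiT checkpoint keeps the attention pooling head (use_head defaults to True) ...
lit_config = AIMv2Config.from_pretrained("yaswanthgali/aimv2-large-patch14-224-lit-HF")
print(lit_config.vision_config.use_head)  # expected: True

# ... while the vision-only checkpoint has it switched off explicitly.
vision_config = AIMv2VisionConfig.from_pretrained("yaswanthgali/aimv2-large-patch14-224-HF")
print(vision_config.use_head)  # expected: False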

View File

@ -158,7 +158,11 @@ class AIMv2VisionEmbeddings(nn.Module):
if self.config.image_size != height or self.config.image_size != width:
pos_embed = self.build_2d_sincos_position_embedding(
-height // self.patch_size, width // self.patch_size, embed_dim=self.config.hidden_size
+height // self.patch_size,
+width // self.patch_size,
+embed_dim=self.config.hidden_size,
+device=hidden_states.device,
+dtype=hidden_states.dtype,
)
else:
pos_embed = self.position_embedding(self.position_ids)
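For reference, `build_2d_sincos_position_embedding` returns a fixed (non-learned) table computed from the grid shape alone, which is why it can be rebuilt on the fly for non-square or larger inputs; below is a self-contained sketch of the standard ViT-style construction with the `device`/`dtype` arguments this change threads through (the function name and body are illustrative, not copied from the module):

import torch

def sincos_2d(h, w, embed_dim=1024, temperature=10000.0, device=None, dtype=torch.float32):
    # One frequency band per 4 channels: sin/cos over the y axis, sin/cos over the x axis.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(h, device=device, dtype=dtype),
        torch.arange(w, device=device, dtype=dtype),
        indexing="ij",
    )
    dim = embed_dim // 4
    omega = 1.0 / temperature ** (torch.arange(dim, device=device, dtype=dtype) / dim)
    out_y = grid_y.flatten()[:, None] * omega[None, :]
    out_x = grid_x.flatten()[:, None] * omega[None, :]
    # Shape (1, h * w, embed_dim), created directly in the requested device and dtype.
    return torch.cat([out_y.sin(), out_y.cos(), out_x.sin(), out_x.cos()], dim=1)[None, :, :]

print(sincos_2d(16, 16).shape)  # torch.Size([1, 256, 1024]) for a 224px image with 14px patches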

View File

@ -83,7 +83,8 @@ class AIMv2VisionConfig(SiglipVisionConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_head (`bool`, *optional*, defaults to `True`):
Whether to use the attention pooling head or not.
"""
"""
def __init__(
self,
hidden_size: int = 1024,
@ -98,9 +99,9 @@ class AIMv2VisionConfig(SiglipVisionConfig):
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
-hidden_act: str ="silu",
-initializer_range: float =0.02,
-use_head: bool =True,
+hidden_act: str = "silu",
+initializer_range: float = 0.02,
+use_head: bool = True,
is_causal: bool = False,
**kwargs,
):
@ -124,7 +125,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
-self.is_causal=is_causal
+self.is_causal = is_causal
del self.layer_norm_eps
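A short usage sketch of the reformatted signature (the keyword values are illustrative; the attribute names come from the `__init__` shown above):

from transformers import AIMv2VisionConfig

# A vision-only setup: no attention pooling head, bidirectional attention.
vision_config = AIMv2VisionConfig(
    hidden_size=1024,
    hidden_act="silu",
    initializer_range=0.02,
    use_head=False,
    is_causal=False,
)
print(vision_config.use_head, vision_config.is_causal)  # False False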
@ -175,7 +176,8 @@ class AIMv2TextConfig(SiglipTextConfig):
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
"""
"""
def __init__(
self,
vocab_size: int = 49408,
@ -188,12 +190,12 @@ class AIMv2TextConfig(SiglipTextConfig):
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
-hidden_act: str="silu",
-pad_token_id: int=None,
-bos_token_id: int=None,
+hidden_act: str = "silu",
+pad_token_id: int = None,
+bos_token_id: int = None,
eos_token_id: int = 49407,
max_position_embeddings: int = 77,
-initializer_range: bool=0.02,
+initializer_range: bool = 0.02,
is_causal: bool = True,
**kwargs,
):
@ -217,7 +219,7 @@ class AIMv2TextConfig(SiglipTextConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
-self.is_causal=is_causal
+self.is_causal = is_causal
del self.bos_token_id
del self.pad_token_id
@ -270,6 +272,7 @@ class AIMv2Config(SiglipConfig):
>>> config = AIMv2Config.from_text_vision_configs(config_text, config_vision)
```"""
def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
):
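Building on the docstring example above, a hedged sketch of composing the full config from the two sub-configs; the printed flags follow from the defaults in the signatures above (`is_causal=True` for text, `False` for vision):

from transformers import AIMv2Config, AIMv2TextConfig, AIMv2VisionConfig

config_text = AIMv2TextConfig()
config_vision = AIMv2VisionConfig()

config = AIMv2Config.from_text_vision_configs(config_text, config_vision)
print(config.text_config.is_causal, config.vision_config.is_causal)  # True False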
@ -378,7 +381,11 @@ class AIMv2VisionEmbeddings(nn.Module):
if self.config.image_size != height or self.config.image_size != width:
pos_embed = self.build_2d_sincos_position_embedding(
-height // self.patch_size, width // self.patch_size, embed_dim=self.config.hidden_size
+height // self.patch_size,
+width // self.patch_size,
+embed_dim=self.config.hidden_size,
+device=hidden_states.device,
+dtype=hidden_states.dtype,
)
else:
pos_embed = self.position_embedding(self.position_ids)
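The point of threading `device`/`dtype` through both copies of this call: the sin-cos table is created from scratch inside `forward`, so building it directly in the activations' device and dtype avoids adding a CPU float32 tensor to, say, half-precision GPU activations. A toy illustration with stand-in tensors:

import torch

hidden_states = torch.randn(1, 256, 1024, dtype=torch.float16)  # e.g. fp16 activations
pos_embed = torch.zeros(
    1, 256, 1024, device=hidden_states.device, dtype=hidden_states.dtype
)  # stand-in for the sin-cos table, built in the matching device/dtype
print((hidden_states + pos_embed).dtype)  # torch.float16, no silent upcast to float32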
@ -408,7 +415,7 @@ class AIMv2Attention(nn.Module):
)
self.num_key_value_groups = 1
-self.scaling = self.head_dim ** -0.5
+self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
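The only functional quantity here is the pre-computed softmax scale; as a reminder, `head_dim ** -0.5` is the usual 1/sqrt(d) factor in scaled dot-product attention (toy shapes, not tied to this module):

import torch

head_dim = 64
scaling = head_dim**-0.5                 # 1 / sqrt(head_dim), same quantity as above
q = torch.randn(1, 8, 16, head_dim)      # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, head_dim)
v = torch.randn(1, 8, 16, head_dim)

attn_weights = torch.softmax(q @ k.transpose(-2, -1) * scaling, dim=-1)
print((attn_weights @ v).shape)          # torch.Size([1, 8, 16, 64])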

View File

@ -19,6 +19,7 @@ import tempfile
import unittest
import numpy as np
import requests
from pytest import mark
from transformers import AIMv2Config, AIMv2TextConfig, AIMv2VisionConfig
@ -26,6 +27,7 @@ from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
@ -57,7 +59,9 @@ if is_torch_available():
if is_vision_available():
pass
+from PIL import Image
+from transformers import AutoImageProcessor, AutoProcessor
class AIMv2VisionModelTester:
@ -441,7 +445,7 @@ class AIMv2ModelTest(AIMv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
def test_model_get_set_embeddings(self):
pass
-# override as the `logit_scale` parameter initialization is different for CLIP
+# Override as the `logit_scale` parameter initialization is different for AIMv2
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -569,3 +573,103 @@ class AIMv2ModelTest(AIMv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
)
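For context, the eager-vs-SDPA comparison that ends here boils down to loading the same weights with two attention backends and checking that the logits agree within tolerance; a hedged standalone sketch (`attn_implementation` is the standard Transformers switch, and the checkpoint, prompts and tolerances mirror the tests in this file):

import requests
import torch
from PIL import Image
from transformers import AIMv2Model, AutoProcessor

name = "yaswanthgali/aimv2-large-patch14-224-lit-HF"
processor = AutoProcessor.from_pretrained(name)
image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt"
)

# Same weights, two attention backends.
model_eager = AIMv2Model.from_pretrained(name, attn_implementation="eager")
model_sdpa = AIMv2Model.from_pretrained(name, attn_implementation="sdpa")

with torch.no_grad():
    logits_eager = model_eager(**inputs).logits_per_text
    logits_sdpa = model_sdpa(**inputs).logits_per_text

print(torch.allclose(logits_eager, logits_sdpa, atol=4e-2, rtol=4e-2))  # expected: True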
@require_vision
@require_torch
class AIMv2ModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "yaswanthgali/aimv2-large-patch14-224-lit-HF"
        model = AIMv2Model.from_pretrained(model_name, device_map="auto")
        processor = AutoProcessor.from_pretrained(model_name)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
        inputs = processor(
            text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt"
        ).to(model.device)

        # Forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # Verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        # handle device
        expected_logits = torch.tensor([[34.2415, 24.6724]]).to(model.device)
        self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
@require_vision
class AIMv2VisionModelIntegrationTests(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "yaswanthgali/aimv2-large-patch14-224-HF"
        model = AIMv2VisionModel.from_pretrained(model_name, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(model_name)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
        inputs = processor(image, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output = model(**inputs)

        # Verify logits shape
        self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 256, 1024]))

        # Verify logits slice
        # fmt: off
        expected_logits = torch.tensor(
            [[ 0.0510, 0.0806, -0.0990, -0.0154],
             [ 2.7850, -2.5143, -0.3320, 2.4196],
             [ 2.8179, -2.4089, -0.2770, 2.3218],
             [ 2.7641, -2.4114, -0.3684, 2.2998],
             [ 2.7972, -2.3180, -0.4490, 2.2302],
             [ 2.8584, -2.5322, -0.2302, 2.4936],
             [-2.7849, 2.4121, 1.3670, -1.5514]]).to(model.device)
        # fmt: on

        output_slice = output.last_hidden_state.squeeze(0)[0:7, 0:4]
        self.assertTrue(torch.allclose(output_slice, expected_logits, atol=1e-3))
    @slow
    def test_inference_for_native_resolution(self):
        model_name = "yaswanthgali/aimv2-large-patch14-native-HF"
        model = AIMv2VisionModel.from_pretrained(model_name, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(model_name)

        image = Image.open(
            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
        )
        inputs = processor(image, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output = model(**inputs)

        # Verify logits shape
        self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 1530, 1024]))

        # Verify logits slice
        # fmt: off
        expected_logits = torch.tensor(
            [[-1.3342, 0.3720, 0.0963, 0.4159],
             [-1.5328, 0.4677, 0.0936, 0.4321],
             [-0.3775, -0.2758, -0.0803, -0.5367],
             [-1.3877, 0.5561, -1.9064, -1.1766],
             [-0.5148, 0.0108, -0.4515, -0.6402],
             [-0.3400, -0.1711, -0.1855, -0.4219],
             [-1.2877, -0.0585, -0.1646, 0.7420]]).to(model.device)
        # fmt: on

        output_slice = output.last_hidden_state.squeeze(0)[0:7, 0:4]
        self.assertTrue(torch.allclose(output_slice, expected_logits, atol=1e-3))
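As a usage note, the expected image-text logits above translate into near-certain zero-shot probabilities once passed through a softmax, which is how these similarity scores are normally consumed:

import torch

# Expected image-text logits from the integration test above.
logits_per_image = torch.tensor([[34.2415, 24.6724]])
print(logits_per_image.softmax(dim=-1))  # roughly [[0.9999, 0.0001]]: "a photo of a cat" wins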

View File

@ -176,6 +176,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"AIMv2TextModel",
"AlignTextModel",
"AlignVisionModel",
"ClapTextModel",