Adding 2D pooling for image embeddings

Mayank Chaturvedi 2025-03-05 21:57:54 +00:00
parent 432c645e41
commit 65350cf531
3 changed files with 55 additions and 23 deletions

View File

@@ -256,6 +256,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
        layer_norm_eps: float = 0.000001,
        vision_use_head: bool = False,
        torch_dtype: str = "bfloat16",
        pooled_seq_len: int = 256,
        **kwargs,
    ):
        super().__init__(
@@ -273,6 +274,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
            **kwargs,
        )
        self.pooled_seq_len = pooled_seq_len
        self.vision_use_head = vision_use_head
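
For reference, a minimal usage sketch of the new config field; the import path is an assumption based on the configuration_gemma3 module referenced in this diff, and the other constructor arguments are left at their defaults:

# Hypothetical usage; the import path is assumed, not shown in this diff.
from transformers.models.gemma3.configuration_gemma3 import Gemma3VisionConfig

config = Gemma3VisionConfig()          # pooled_seq_len defaults to 256
assert config.pooled_seq_len == 256    # 256 pooled tokens = a 16x16 grid after 2D pooling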

View File

@@ -25,6 +25,7 @@ from typing import Literal, Optional, Union, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
@@ -44,7 +45,7 @@ from ...utils import (
from ...utils.deprecation import deprecate_kwarg
from ..gemma import GemmaPreTrainedModel
from ..siglip import SiglipVisionModel
from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig
from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig, Gemma3VisionConfig
logger = logging.get_logger(__name__)
@@ -71,6 +72,28 @@ class Gemma3RMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


class Gemma3VisionAvgPool2D(nn.Module):
    def __init__(self, config: Gemma3VisionConfig):
        super().__init__()
        self.config = config

    def forward(self, x):
        """
        Applies 2D average pooling to the (B, seq_len, channels) input, treating the sequence
        as a square (width, width) patch grid, and returns (B, pooled_seq_len, channels).
        """
        batch_size, seq_len, channels = x.shape
        width = int(seq_len**0.5)
        if width * width != seq_len:
            raise ValueError(f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image.")
        final_width = int(self.config.pooled_seq_len**0.5)
        kernel_size = width // final_width
        x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
        x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
        x = x.flatten(2).transpose(1, 2)
        return x


class Gemma3MultimodalInputProjection(nn.Module):
    def __init__(self, vision_dim: int, text_dim: int):
        super().__init__()
@@ -1012,7 +1035,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        self.config = config
        text_config = self.config.text_config
        vision_config = self.config.vision_config
@@ -1028,10 +1050,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
            vision_dim=vision_config.hidden_size, text_dim=text_config.hidden_size
        )
        self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)
        patches_per_image = vision_config.image_size // vision_config.patch_size
        avg_pool_k = patches_per_image**2 // text_config.mm_tokens_per_image
        self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
        self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
        self.vocab_size = text_config.vocab_size
        self.pad_token_id = pad_token_id if (pad_token_id := text_config.pad_token_id) is not None else -1
        self.post_init()
@@ -1076,12 +1095,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
        b, n, l = vision_outputs.shape
        reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
        reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
        pooled_vision_outputs = self.avg_pool(vision_outputs)
        image_features = self.encode_vision(pooled_vision_outputs)
        return image_features

View File

@@ -26,6 +26,7 @@ import PIL.Image
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.nn.functional as F
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
@@ -332,6 +333,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
        layer_norm_eps: float = 0.000001,
        vision_use_head: bool = False,
        torch_dtype: str = "bfloat16",
        pooled_seq_len: int = 256,
        **kwargs,
    ):
        super().__init__(
@@ -349,6 +351,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
            **kwargs,
        )
        self.pooled_seq_len = pooled_seq_len
        self.vision_use_head = vision_use_head
@@ -710,6 +713,28 @@ class Gemma3Processor(ProcessorMixin):
class Gemma3RMSNorm(GemmaRMSNorm):
    pass


class Gemma3VisionAvgPool2D(nn.Module):
    def __init__(self, config: Gemma3VisionConfig):
        super().__init__()
        self.config = config

    def forward(self, x):
        """
        Applies 2D average pooling to the (B, seq_len, channels) input, treating the sequence
        as a square (width, width) patch grid, and returns (B, pooled_seq_len, channels).
        """
        batch_size, seq_len, channels = x.shape
        width = int(seq_len**0.5)
        if width * width != seq_len:
            raise ValueError(
                f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image."
            )
        final_width = int(self.config.pooled_seq_len**0.5)
        kernel_size = width // final_width
        x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
        x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
        x = x.flatten(2).transpose(1, 2)
        return x


class Gemma3MultimodalInputProjection(nn.Module):
@@ -1690,7 +1715,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        self.config = config
        text_config = self.config.text_config
        vision_config = self.config.vision_config
@@ -1708,10 +1732,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
        self.mm_soft_emb_norm = Gemma3RMSNorm(
            vision_config.hidden_size, eps=vision_config.layer_norm_eps
        )
        patches_per_image = vision_config.image_size // vision_config.patch_size
        avg_pool_k = patches_per_image ** 2 // text_config.mm_tokens_per_image
        self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
        self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
        self.vocab_size = text_config.vocab_size
        self.pad_token_id = (
            pad_token_id
@@ -1760,12 +1781,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
        b, n, l = vision_outputs.shape
        reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
        reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
        pooled_vision_outputs = self.avg_pool(vision_outputs)
        image_features = self.encode_vision(pooled_vision_outputs)
        return image_features
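
For context on the removed nn.AvgPool1d path, a hedged comparison sketch on the same assumed 64x64 patch grid: the old 1D pooling averaged runs of 16 consecutive tokens in raster order (a 1x16 strip within a single row), while the new 2D pooling averages a 4x4 spatial block of patches; both yield 256 output tokens.

import torch
import torch.nn as nn
import torch.nn.functional as F

b, c = 1, 8                                   # small assumed sizes for illustration
width, final_width = 64, 16                   # 64x64 patch grid -> 16x16 pooled grid
seq_len = width * width                       # 4096 patch tokens
x = torch.randn(b, seq_len, c)

# Removed path: 1D average over 16 consecutive tokens (raster order).
k1d = seq_len // (final_width * final_width)  # 16
old = nn.AvgPool1d(kernel_size=k1d, stride=k1d)(x.transpose(1, 2)).transpose(1, 2)

# New path: 2D average over 4x4 spatial blocks of the patch grid.
k2d = width // final_width                    # 4
new = x.transpose(1, 2).reshape(b, c, width, width)
new = F.avg_pool2d(new, kernel_size=k2d, stride=k2d).flatten(2).transpose(1, 2)

print(old.shape, new.shape)                   # both torch.Size([1, 256, 8]); different patches averaged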