Adding 2D pooling for image embeddings
commit 65350cf531 (parent 432c645e41)
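Replaces the 1D average pooling of image patch embeddings (nn.AvgPool1d over the flattened patch sequence) with a new Gemma3VisionAvgPool2D module that pools over the 2D patch grid, and adds a pooled_seq_len option (default 256) to Gemma3VisionConfig to control the pooled sequence length.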
@@ -256,6 +256,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
         layer_norm_eps: float = 0.000001,
         vision_use_head: bool = False,
         torch_dtype: str = "bfloat16",
+        pooled_seq_len: int = 256,
         **kwargs,
     ):
         super().__init__(
@@ -273,6 +274,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
             **kwargs,
         )

+        self.pooled_seq_len = pooled_seq_len
         self.vision_use_head = vision_use_head

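For scale: with SigLIP-style vision defaults of image_size=896 and patch_size=14 (assumed here, they are set elsewhere in this config, not in the hunk above), the encoder emits a 64x64 grid of patch embeddings, and pooled_seq_len=256 corresponds to a 16x16 pooled grid. A quick arithmetic sketch:

    # Assumed SigLIP-style defaults; only pooled_seq_len = 256 appears in the diff above.
    image_size, patch_size, pooled_seq_len = 896, 14, 256

    patches_per_side = image_size // patch_size    # 64
    seq_len = patches_per_side**2                  # 4096 patch embeddings per image
    final_width = int(pooled_seq_len**0.5)         # 16
    kernel_size = patches_per_side // final_width  # 4: each pooled token averages a 4x4 patch block
    print(seq_len, final_width, kernel_size)       # 4096 16 4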
@@ -25,6 +25,7 @@ from typing import Literal, Optional, Union, cast

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache, StaticCache
@@ -44,7 +45,7 @@ from ...utils import (
 from ...utils.deprecation import deprecate_kwarg
 from ..gemma import GemmaPreTrainedModel
 from ..siglip import SiglipVisionModel
-from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig
+from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig, Gemma3VisionConfig


 logger = logging.get_logger(__name__)
@@ -71,6 +72,28 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


+class Gemma3VisionAvgPool2D(nn.Module):
+    def __init__(self, config: Gemma3VisionConfig):
+        super().__init__()
+        self.config = config
+
+    def forward(self, x):
+        """
+        Applies average pooling on (B, width, width)
+        to make it (B, final_width, final_width).
+        """
+        batch_size, seq_len, channels = x.shape
+        width = int(seq_len**0.5)
+        if width * width != seq_len:
+            raise ValueError(f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image.")
+        final_width = int(self.config.pooled_seq_len**0.5)
+        kernel_size = width // final_width
+        x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
+        x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+
 class Gemma3MultimodalInputProjection(nn.Module):
     def __init__(self, vision_dim: int, text_dim: int):
         super().__init__()
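As a sanity check on the new module's forward pass, a minimal standalone sketch (the 4096-token, 1152-channel shapes are assumptions matching the SigLIP-style defaults above, with pooled_seq_len = 256):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4096, 1152)       # (batch, seq_len, channels); 4096 = 64x64 patch grid
    batch_size, seq_len, channels = x.shape
    width = int(seq_len**0.5)            # 64
    final_width = int(256**0.5)          # 16 (pooled_seq_len = 256)
    kernel_size = width // final_width   # 4
    x = x.transpose(1, 2).reshape(batch_size, channels, width, width)  # (2, 1152, 64, 64)
    x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)   # (2, 1152, 16, 16)
    x = x.flatten(2).transpose(1, 2)                                   # (2, 256, 1152)
    assert x.shape == (2, 256, 1152)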
@@ -1012,7 +1035,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):

     def __init__(self, config: Gemma3Config):
         super().__init__(config)
-
         self.config = config
         text_config = self.config.text_config
         vision_config = self.config.vision_config
@@ -1028,10 +1050,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
             vision_dim=vision_config.hidden_size, text_dim=text_config.hidden_size
         )
         self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)
-
-        patches_per_image = vision_config.image_size // vision_config.patch_size
-        avg_pool_k = patches_per_image**2 // text_config.mm_tokens_per_image
-        self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
+        self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
         self.vocab_size = text_config.vocab_size
         self.pad_token_id = pad_token_id if (pad_token_id := text_config.pad_token_id) is not None else -1
         self.post_init()
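The behavioral difference from the removed nn.AvgPool1d path: with a 64x64 patch grid and mm_tokens_per_image = 256 (assumed values), the 1D kernel was 4096 // 256 = 16, averaging 16 consecutive tokens of the row-major sequence (a quarter of one 64-token row), while the 2D pool averages 4x4 spatial blocks. A toy 4x4 grid makes the difference visible:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    grid = torch.arange(16.0).reshape(1, 1, 4, 4)  # one channel, 4x4 "patch grid"
    seq = grid.flatten(2)                          # (1, 1, 16), row-major token sequence

    # 1D: averages runs of 4 consecutive tokens, i.e. whole rows of the grid.
    print(nn.AvgPool1d(kernel_size=4, stride=4)(seq))              # [[[ 1.5,  5.5,  9.5, 13.5]]]

    # 2D: averages each 2x2 spatial neighborhood instead.
    print(F.avg_pool2d(grid, kernel_size=2, stride=2).flatten(2))  # [[[ 2.5,  4.5, 10.5, 12.5]]]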
@@ -1076,12 +1095,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
-        b, n, l = vision_outputs.shape
-        reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
-        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
-        reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
-        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
-        pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
+        pooled_vision_outputs = self.avg_pool(vision_outputs)
         image_features = self.encode_vision(pooled_vision_outputs)
         return image_features

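With the reshape handled inside Gemma3VisionAvgPool2D, get_image_features no longer needs the permute/contiguous/view round-trip: the module both accepts and returns (batch, seq_len, channels).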
@@ -26,6 +26,7 @@ import PIL.Image
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
+import torch.nn.functional as F

 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache, StaticCache
@@ -332,6 +333,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
         layer_norm_eps: float = 0.000001,
         vision_use_head: bool = False,
         torch_dtype: str = "bfloat16",
+        pooled_seq_len: int = 256,
         **kwargs,
     ):
         super().__init__(
@@ -349,6 +351,7 @@ class Gemma3VisionConfig(SiglipVisionConfig):
             **kwargs,
         )

+        self.pooled_seq_len = pooled_seq_len
         self.vision_use_head = vision_use_head

@@ -710,6 +713,28 @@ class Gemma3Processor(ProcessorMixin):
 class Gemma3RMSNorm(GemmaRMSNorm):
     pass

+class Gemma3VisionAvgPool2D(nn.Module):
+    def __init__(self, config: Gemma3VisionConfig):
+        super().__init__()
+        self.config = config
+
+    def forward(self, x):
+        """
+        Applies average pooling on (B, width, width)
+        to make it (B, final_width, final_width).
+        """
+        batch_size, seq_len, channels = x.shape
+        width = int(seq_len**0.5)
+        if width * width != seq_len:
+            raise ValueError(
+                f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image."
+            )
+        final_width = int(self.config.pooled_seq_len**0.5)
+        kernel_size = width // final_width
+        x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
+        x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+        x = x.flatten(2).transpose(1, 2)
+        return x

 class Gemma3MultimodalInputProjection(nn.Module):

@@ -1690,7 +1715,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):

    def __init__(self, config: Gemma3Config):
        super().__init__(config)
-
        self.config = config
        text_config = self.config.text_config
        vision_config = self.config.vision_config
@@ -1708,10 +1732,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
        self.mm_soft_emb_norm = Gemma3RMSNorm(
            vision_config.hidden_size, eps=vision_config.layer_norm_eps
        )
-
-        patches_per_image = vision_config.image_size // vision_config.patch_size
-        avg_pool_k = patches_per_image ** 2 // text_config.mm_tokens_per_image
-        self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
+        self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
        self.vocab_size = text_config.vocab_size
        self.pad_token_id = (
            pad_token_id
@@ -1760,12 +1781,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
-        b, n, l = vision_outputs.shape
-        reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
-        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
-        reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
-        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
-        pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
+        pooled_vision_outputs = self.avg_pool(vision_outputs)
         image_features = self.encode_vision(pooled_vision_outputs)
         return image_features
