From 7478568172716393a16518512b06262c52da5cf8 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 26 Jun 2025 17:19:49 +0200 Subject: [PATCH] push current change --- .../models/gemma3n/configuration_gemma3n.py | 81 ++-- .../models/gemma3n/convert_gemma3n_weights.py | 2 +- .../models/gemma3n/modeling_gemma3n.py | 24 +- .../models/gemma3n/modular_gemma3n.py | 360 +++++++++--------- 4 files changed, 237 insertions(+), 230 deletions(-) diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index ca1a0671774..d080ea7eca2 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -325,33 +325,24 @@ class Gemma3nAudioConfig(PretrainedConfig): The epsilon used by the rms normalization layers. gradient_clipping (`float`, *optional*, defaults to 10000000000.0): Clipping value used to stablize extremely large gradient values. - conf_attention_chunk_size (`int`, *optional*, defaults to 12): - The sub-sequence size for local attention processing inside the Conformer ("conf") section of the - Universal Speech Model. - conf_attention_context_left (`int`, *optional*, defaults to 13): - The left context size of the local attention inside the Conformer ("conf") section of the - Universal Speech Model. - conf_attention_context_right (`int`, *optional*, defaults to 0): - The right context size of the local attention inside the Conformer ("conf") section of the - Universal Speech Model. - conf_attention_logit_cap (`float`, *optional*, defaults to 50.0): - Logit cap applied during local attention inside the Conformer ("conf") section of the - Universal Speech Model. - conf_num_attention_heads (`int`, *optional*, defaults to 8): - The number of attention heads in local attention inside the Conformer ("conf") section of the - Universal Speech Model. - conf_num_hidden_layers (`int`, *optional*, defaults to 12): - The number of layers that use local attention inside the Conformer ("conf") section of the - Universal Speech Model. - conf_conv_kernel_size (`int`, *optional*, defaults to 5): - Convolution kernel size for the conformer block inside the Conformer ("conf") section of the - Universal Speech Model. - conf_reduction_factor (`int`, *optional*, defaults to 4): - Reduction factor used in the conformer block inside the Conformer ("conf") section of the - Universal Speech Model. - conf_residual_weight (`float`, *optional*, defaults to 0.5): - Residual connection weight inside the Conformer ("conf") section of the - Universal Speech Model. + attention_chunk_size (`int`, *optional*, defaults to 12): + The sub-sequence size for local attention processing. + attention_context_left (`int`, *optional*, defaults to 13): + The left context size of the local attention. + attention_context_right (`int`, *optional*, defaults to 0): + The right context size of the local attention. + attention_logit_cap (`float`, *optional*, defaults to 50.0): + Logit cap applied during local attention. + num_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads in local attention. + num_hidden_layers (`int`, *optional*, defaults to 12): + The number of layers that use local attention. + conv_kernel_size (`int`, *optional*, defaults to 5): + Convolution kernel size for the conformer block. + reduction_factor (`int`, *optional*, defaults to 4): + Reduction factor used in the conformer block. 
+ residual_weight (`float`, *optional*, defaults to 0.5): + Residual connection weight. sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`): The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection ("sscp") section of the Universal Speech Model. @@ -395,15 +386,15 @@ class Gemma3nAudioConfig(PretrainedConfig): hidden_size: int = 1536, rms_norm_eps: float = 1e-6, gradient_clipping: float = 10_000_000_000.0, - conf_attention_chunk_size: int = 12, - conf_attention_context_left: int = 13, - conf_attention_context_right: int = 0, - conf_attention_logit_cap: float = 50.0, - conf_num_attention_heads: int = 8, - conf_num_hidden_layers: int = 12, - conf_conv_kernel_size: int = 5, - conf_reduction_factor: int = 4, - conf_residual_weight: float = 0.5, + attention_chunk_size: int = 12, + attention_context_left: int = 13, + attention_context_right: int = 0, + attention_logit_cap: float = 50.0, + num_attention_heads: int = 8, + num_hidden_layers: int = 12, + conv_kernel_size: int = 5, + reduction_factor: int = 4, + residual_weight: float = 0.5, sscp_conv_channel_size: tuple[int, int] = (128, 32), sscp_conv_group_norm_eps: float = 1e-3, sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = ( @@ -423,15 +414,15 @@ class Gemma3nAudioConfig(PretrainedConfig): self.vocab_size = vocab_size self.vocab_offset = vocab_offset self.gradient_clipping = gradient_clipping - self.conf_attention_chunk_size = conf_attention_chunk_size - self.conf_attention_context_left = conf_attention_context_left - self.conf_attention_context_right = conf_attention_context_right - self.conf_attention_logit_cap = conf_attention_logit_cap - self.conf_num_attention_heads = conf_num_attention_heads - self.conf_num_hidden_layers = conf_num_hidden_layers - self.conf_conv_kernel_size = conf_conv_kernel_size - self.conf_reduction_factor = conf_reduction_factor - self.conf_residual_weight = conf_residual_weight + self.attention_chunk_size = attention_chunk_size + self.attention_context_left = attention_context_left + self.attention_context_right = attention_context_right + self.attention_logit_cap = attention_logit_cap + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.conv_kernel_size = conv_kernel_size + self.reduction_factor = reduction_factor + self.residual_weight = residual_weight self.sscp_conv_channel_size = sscp_conv_channel_size self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps self.sscp_conv_kernel_size = sscp_conv_kernel_size diff --git a/src/transformers/models/gemma3n/convert_gemma3n_weights.py b/src/transformers/models/gemma3n/convert_gemma3n_weights.py index 2f25ca56d46..71187a2b01e 100644 --- a/src/transformers/models/gemma3n/convert_gemma3n_weights.py +++ b/src/transformers/models/gemma3n/convert_gemma3n_weights.py @@ -240,7 +240,7 @@ def convert_audio_encoder_weights( converted_weights: list[Any] = [] if path.startswith(_AUDIO_ENCODER_CONFORMER): - assert weights.shape[0] == config.conf_num_hidden_layers + assert weights.shape[0] == config.num_hidden_layers for i, matrix in enumerate(weights): if "fflayer_end" in path: diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 0817e16451a..437edc04211 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -149,11 +149,11 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module): super().__init__() 
self.config = config - self.num_heads = self.config.conf_num_attention_heads + self.num_heads = self.config.num_attention_heads self.channels = self.config.hidden_size self.head_dim = self.channels // self.num_heads - self.max_backward = max(0, self.config.conf_attention_context_left - 1) - self.max_forward = self.config.conf_attention_context_right + self.max_backward = max(0, self.config.attention_context_left - 1) + self.max_forward = self.config.attention_context_right self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False) @@ -319,7 +319,7 @@ class Gemma3nAudioAttention(nn.Module): super().__init__() self.config = config - self.num_heads = self.config.conf_num_attention_heads + self.num_heads = self.config.num_attention_heads self.hidden_size = self.config.hidden_size self.head_dim = self.hidden_size // self.num_heads @@ -826,7 +826,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module): self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False) self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False) self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size) - self.post_layer_scale = torch.tensor(self.config.conf_residual_weight) + self.post_layer_scale = torch.tensor(self.config.residual_weight) def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: residual = audio_encodings @@ -850,7 +850,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module): self.depthwise_conv1d = nn.Conv1d( in_channels=self.config.hidden_size, out_channels=self.config.hidden_size, - kernel_size=self.config.conf_conv_kernel_size, + kernel_size=self.config.conv_kernel_size, stride=1, padding=0, # Manual causal padding groups=self.config.hidden_size, # Depthwise @@ -860,7 +860,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module): self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False) - self.causal_padding = self.config.conf_conv_kernel_size - 1 + self.causal_padding = self.config.conv_kernel_size - 1 def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: audio_encodings_residual = audio_encodings # Save for residual connection @@ -922,9 +922,7 @@ class Gemma3nAudioEncoder(PreTrainedModel): self.config = config self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config) - self.conformer = nn.ModuleList( - [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)] - ) + self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)]) def forward( self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor @@ -973,10 +971,10 @@ class Gemma3nAudioEncoder(PreTrainedModel): for block in self.conformer: audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask - if self.config.conf_reduction_factor > 1: - audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor] + if self.config.reduction_factor > 1: + audio_encodings = audio_encodings[:, :: self.config.reduction_factor] # Reduce the mask as well - current_mask = current_mask[:, :: self.config.conf_reduction_factor] + current_mask = current_mask[:, :: self.config.reduction_factor] audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0) return audio_encodings, current_mask diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py 
b/src/transformers/models/gemma3n/modular_gemma3n.py index a3ffa710d84..646b4f830a6 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -21,6 +21,7 @@ from typing import Any, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F +from transformers.models.phi4_multimodal.modeling_phi4_multimodal import unfold_tensor from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, HybridCache @@ -407,15 +408,15 @@ class Gemma3nAudioConfig(PretrainedConfig): hidden_size: int = 1536, rms_norm_eps: float = 1e-6, gradient_clipping: float = 10_000_000_000.0, - conf_attention_chunk_size: int = 12, - conf_attention_context_left: int = 13, - conf_attention_context_right: int = 0, - conf_attention_logit_cap: float = 50.0, - conf_num_attention_heads: int = 8, - conf_num_hidden_layers: int = 12, - conf_conv_kernel_size: int = 5, - conf_reduction_factor: int = 4, - conf_residual_weight: float = 0.5, + attention_chunk_size: int = 12, + attention_context_left: int = 13, + attention_context_right: int = 0, + attention_logit_cap: float = 50.0, + num_attention_heads: int = 8, + num_hidden_layers: int = 12, + conv_kernel_size: int = 5, + reduction_factor: int = 4, + residual_weight: float = 0.5, sscp_conv_channel_size: tuple[int, int] = (128, 32), sscp_conv_group_norm_eps: float = 1e-3, sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = ( @@ -435,15 +436,15 @@ class Gemma3nAudioConfig(PretrainedConfig): self.vocab_size = vocab_size self.vocab_offset = vocab_offset self.gradient_clipping = gradient_clipping - self.conf_attention_chunk_size = conf_attention_chunk_size - self.conf_attention_context_left = conf_attention_context_left - self.conf_attention_context_right = conf_attention_context_right - self.conf_attention_logit_cap = conf_attention_logit_cap - self.conf_num_attention_heads = conf_num_attention_heads - self.conf_num_hidden_layers = conf_num_hidden_layers - self.conf_conv_kernel_size = conf_conv_kernel_size - self.conf_reduction_factor = conf_reduction_factor - self.conf_residual_weight = conf_residual_weight + self.attention_chunk_size = attention_chunk_size + self.attention_context_left = attention_context_left + self.attention_context_right = attention_context_right + self.attention_logit_cap = attention_logit_cap + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.conv_kernel_size = conv_kernel_size + self.reduction_factor = reduction_factor + self.residual_weight = residual_weight self.sscp_conv_channel_size = sscp_conv_channel_size self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps self.sscp_conv_kernel_size = sscp_conv_kernel_size @@ -706,16 +707,77 @@ class Gemma3nRMSNorm(Gemma3RMSNorm): # ==== Audio Encoder ==== +def _relative_shift( + term_bd_before_shift: torch.Tensor, + batch_size: int, + num_heads: int, + num_query_blocks: int, + query_block_size: int, + key_context_size: int, + max_span_plus_1: int, +) -> torch.Tensor: + """Performs the relative shift. + + Args: + term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size + (B), num_heads (N), num_query_blocks (U), query_block_size (W), + key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1). + + Returns: + Tensor of shape [B, N, U, W, C]. 
+ """ + # term_bd_before_shift shape: [B, N, U, W, F_span] + # Target shape after shift: [B, N, U, W, C] + + # Padding amount for the last dimension (F_span) to become (C + 1) + # C = key_context_size + # F_span = max_span_plus_1 + pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1 + + # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...) + # We only pad the last dimension on the right. + padding_tuple = (0, pad_amount_last_dim) + + term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple) + # Shape after pad: [B, N, U, W, C+1] + + # Reshape for slicing (emulating JAX's behavior) + # [B, N, U, W * (C+1)] + term_bd_reshaped = term_bd_padded.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size * (key_context_size + 1), + ) + ) + + # Slice to effective [B, N, U, W * C] + term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size] + + # Reshape back to [B, N, U, W, C] + term_bd_shifted = term_bd_sliced.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + ) + ) + return term_bd_shifted + + class Gemma3nAudioRelativePositionEmbedding(nn.Module): def __init__(self, config: Gemma3nAudioConfig): super().__init__() self.config = config - self.num_heads = self.config.conf_num_attention_heads + self.num_heads = self.config.num_attention_heads self.channels = self.config.hidden_size self.head_dim = self.channels // self.num_heads - self.max_backward = max(0, self.config.conf_attention_context_left - 1) - self.max_forward = self.config.conf_attention_context_right + self.max_backward = max(0, self.config.attention_context_left - 1) + self.max_forward = self.config.attention_context_right self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False) @@ -736,144 +798,82 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module): timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) return timing_signal.type(dtype) - def _relative_shift( - self, - term_bd_before_shift: torch.Tensor, - batch_size: int, - num_heads: int, - num_query_blocks: int, - query_block_size: int, - key_context_size: int, - max_span_plus_1: int, - ) -> torch.Tensor: - """Performs the relative shift. - - Args: - term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size - (B), num_heads (N), num_query_blocks (U), query_block_size (W), - key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1). - - Returns: - Tensor of shape [B, N, U, W, C]. - """ - # term_bd_before_shift shape: [B, N, U, W, F_span] - # Target shape after shift: [B, N, U, W, C] - - # Padding amount for the last dimension (F_span) to become (C + 1) - # C = key_context_size - # F_span = max_span_plus_1 - pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1 - - # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...) - # We only pad the last dimension on the right. 
-        padding_tuple = (0, pad_amount_last_dim)
-
-        term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
-        # Shape after pad: [B, N, U, W, C+1]
-
-        # Reshape for slicing (emulating JAX's behavior)
-        # [B, N, U, W * (C+1)]
-        term_bd_reshaped = term_bd_padded.reshape(
-            (
-                batch_size,
-                num_heads,
-                num_query_blocks,
-                query_block_size * (key_context_size + 1),
-            )
-        )
-
-        # Slice to effective [B, N, U, W * C]
-        term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
-
-        # Reshape back to [B, N, U, W, C]
-        term_bd_shifted = term_bd_sliced.reshape(
-            (
-                batch_size,
-                num_heads,
-                num_query_blocks,
-                query_block_size,
-                key_context_size,
-            )
-        )
-        return term_bd_shifted
-
-    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
-        # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
-        # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
-        # C = W + L + R (key_context_size)
-        # F_span = L + R + 1 (max_span + 1)
-
-        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
-        _, _, key_context_size, _, _ = keys.shape
-
+    def forward(self) -> torch.Tensor:
         # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
         # Length is L+R+1 = self.max_span + 1
-        pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
+        pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=self.pos_proj.weight.device).unsqueeze(
             0
         )  # Shape [1, F_span]
         max_span_plus_1 = pos_indices.shape[1]  # F_span
-
-        sin_emb_timing_signal = self._get_timing_signal_1d_pos(
-            pos_indices, dtype=queries.dtype
-        )  # Shape [1, F_span, self.channels]
-
+        sin_emb_timing_signal = self._get_timing_signal_1d_pos(pos_indices, dtype=self.pos_proj.weight.dtype)
         # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
         projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
         # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
         sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
             0
         )  # Shape [F, N, H]
-
-        # term_ac: Query-Key content interaction
-        # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
-        # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
-        queries_p = queries.permute(0, 3, 1, 2, 4)  # [B, N, U, W, H]
-        keys_p_t = keys.permute(0, 3, 1, 4, 2)  # [B, N, U, H, C]
-        term_ac = torch.matmul(queries_p, keys_p_t)  # [B, N, U, W, C]
-
-        # term_bd: Query-Position interaction
-        # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
-        # queries shape: [B, U, W, N, H]
-        # sin_emb shape: [F, N, H]
-        # Target output shape: [B, N, U, W, F]
-
-        # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
-        q_permuted = queries.permute(0, 3, 1, 2, 4)
-
-        # Permute sin_emb to [N, H, F] to prepare for matmul
-        # sin_emb original is [F, N, H]
         s_permuted = sin_emb.permute(1, 2, 0)  # Shape: [N, H, F]
+        return (s_permuted, max_span_plus_1)

-        # Reshape queries for matmul: [B, N, U*W, H]
-        q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
-        # Perform matmul: [B, N, U*W, H] @ [N, H, F]
-        # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
-        # Result: [B, N, U*W, F]
-        term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
+def apply_audio_rotary_embedding(queries, keys, position_embeddings):
+    s_permuted, max_span_plus_1 = position_embeddings
+
##### SPLIT THIS + batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape + _, _, key_context_size, _, _ = keys.shape - # Reshape to target [B, N, U, W, F] - term_bd_unshifed = term_bd_unshifed_matmul.reshape( - batch_size, - num_heads, - num_query_blocks, - query_block_size, - max_span_plus_1, - ) + q_permuted = queries.permute(0, 3, 1, 2, 4) + q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim) + queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H] + keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C] + term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C] - # Apply relative shift to term_bd_unshifed - term_bd_shifted = self._relative_shift( - term_bd_unshifed, - batch_size, - num_heads, - num_query_blocks, - query_block_size, - key_context_size, - max_span_plus_1, - ) # Shape [B, N, U, W, C] + # Perform matmul: [B, N, U*W, H] @ [N, H, F] + # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F] + # Result: [B, N, U*W, F] + term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted) - return term_ac + term_bd_shifted + # Reshape to target [B, N, U, W, F] + term_bd_unshifed = term_bd_unshifed_matmul.reshape( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + max_span_plus_1, + ) + + # Apply relative shift to term_bd_unshifed + term_bd_shifted = _relative_shift( + term_bd_unshifed, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ) # Shape [B, N, U, W, C] + + return term_ac + term_bd_shifted + + +def _frame_hidden_states(config, hidden_states: torch.Tensor) -> torch.Tensor: + pad_left = config.max_past_horizon + pad_right = config.max_future_horizon + config.chunk_size - 1 + + # Pad only the time dimension (dim=1) + padding = (0, 0, pad_left, pad_right) # (dim2_left, dim2_right, dim1_left, dim1_right) + hidden_states = F.pad(hidden_states, padding) + + frame_len = config.context_size + frame_step = config.chunk_size + + # Unfold time dimension into overlapping frames + unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step) + + # Move the last dim (context window) to slot after batch + unfolded = unfolded.permute(0, 1, 3, 2) + return unfolded.contiguous() class Gemma3nAudioAttention(nn.Module): @@ -881,7 +881,7 @@ class Gemma3nAudioAttention(nn.Module): super().__init__() self.config = config - self.num_heads = self.config.conf_num_attention_heads + self.num_heads = self.config.num_attention_heads self.hidden_size = self.config.hidden_size self.head_dim = self.hidden_size // self.num_heads @@ -891,7 +891,6 @@ class Gemma3nAudioAttention(nn.Module): self.attention_logits_soft_cap = self.config.conf_attention_logit_cap self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon - self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config) self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,))) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -901,23 +900,45 @@ class Gemma3nAudioAttention(nn.Module): q_scale = self.head_dim**-0.5 r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0)) self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False) - - lower_causal_mask = torch.tril( - torch.ones((self.context_size, self.chunk_size), dtype=torch.bool), - diagonal=0, - ).T - upper_causal_mask = torch.tril( - torch.ones((self.chunk_size, self.context_size), 
dtype=torch.bool),
-            diagonal=self.max_past_horizon + self.max_future_horizon,
-        )
-        local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
-        local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
-        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
-
         self.register_buffer(
             "softcap",
-            torch.tensor(self.attention_logits_soft_cap).float(),
+            torch.tensor(self.config.attention_logit_cap).float(),
             persistent=False,
+        )  # TODO do we even need a tensor for that?
+
+    def forward(
+        self, hidden_states: torch.Tensor, position_embeddings, attention_mask: torch.BoolTensor, **kwargs
+    ) -> torch.Tensor:
+        hidden_shape = hidden_states.shape[:-1]
+        query_states = self.q_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
+        key_states = self.k_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
+        value_states = self.v_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
+
+        per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale).view(1, 1, 1, self.head_dim)
+        scaling = self.q_scale * per_dim_scale_sp
+
+        # shape is B, audio_length, num_heads, head_dim
+        num_blocks = (query_states.shape[1] + self.chunk_size - 1) // self.chunk_size
+        padding_length = num_blocks * self.chunk_size - query_states.shape[1]
+        query_states = torch.nn.functional.pad(query_states, (0, 0, 0, 0, 0, padding_length))
+        query_states = query_states.reshape(-1, num_blocks, self.chunk_size, self.num_heads, self.head_dim)
+
+        key_states = _frame_hidden_states(self.config, key_states)
+        value_states = _frame_hidden_states(self.config, value_states)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=scaling,
+            **kwargs,
         )

     def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
@@ -1089,15 +1110,14 @@ class Gemma3nAudioAttention(nn.Module):
         context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
         context_vectors = context_vectors.reshape(
             (
-                batch_size,
-                num_query_blocks * self.chunk_size,
+                hidden_shape[0],
+                query_states.shape[1] * self.chunk_size,
                 self.num_heads,
                 self.head_dim,
             )
         )
-        context_vectors = context_vectors[:, :q_time]
-
-        return context_vectors
+        context_vectors = context_vectors[:, : hidden_shape[1]]
+        return context_vectors, attn_weights


 class Gemma3nAudioCumulativeGroupNorm(nn.Module):
@@ -1388,7 +1408,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
         self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
         self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
         self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
-        self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+        self.post_layer_scale = torch.tensor(self.config.residual_weight)

     def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
         residual = audio_encodings
@@ -1412,7 +1432,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
         self.depthwise_conv1d = nn.Conv1d(
             in_channels=self.config.hidden_size,
out_channels=self.config.hidden_size, - kernel_size=self.config.conf_conv_kernel_size, + kernel_size=self.config.conv_kernel_size, stride=1, padding=0, # Manual causal padding groups=self.config.hidden_size, # Depthwise @@ -1422,7 +1442,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module): self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False) - self.causal_padding = self.config.conf_conv_kernel_size - 1 + self.causal_padding = self.config.conv_kernel_size - 1 def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: audio_encodings_residual = audio_encodings # Save for residual connection @@ -1484,9 +1504,7 @@ class Gemma3nAudioEncoder(PreTrainedModel): self.config = config self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config) - self.conformer = nn.ModuleList( - [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)] - ) + self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)]) def forward( self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor @@ -1535,10 +1553,10 @@ class Gemma3nAudioEncoder(PreTrainedModel): for block in self.conformer: audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask - if self.config.conf_reduction_factor > 1: - audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor] + if self.config.reduction_factor > 1: + audio_encodings = audio_encodings[:, :: self.config.reduction_factor] # Reduce the mask as well - current_mask = current_mask[:, :: self.config.conf_reduction_factor] + current_mask = current_mask[:, :: self.config.reduction_factor] audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0) return audio_encodings, current_mask
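
Illustration (not part of the patch): the module-level `_relative_shift` helper introduced above uses the usual pad/reshape/slice relative-shift trick described in its docstring. The snippet below is a minimal, self-contained sketch of that trick on a toy tensor; the sizes are made up for the example and nothing here depends on the Gemma3n classes.

# Minimal sketch of the pad/reshape/slice relative-shift trick.
# Shapes follow the _relative_shift docstring: [B, N, U, W, F_span] -> [B, N, U, W, C].
import torch
import torch.nn.functional as F

batch_size, num_heads, num_query_blocks, query_block_size = 1, 1, 1, 2  # B, N, U, W (toy sizes)
left_context, right_context = 2, 0                                      # L, R
max_span_plus_1 = left_context + right_context + 1                      # F_span = 3
key_context_size = query_block_size + left_context + right_context      # C = 4

term_bd = torch.arange(
    batch_size * num_heads * num_query_blocks * query_block_size * max_span_plus_1,
    dtype=torch.float32,
).reshape(batch_size, num_heads, num_query_blocks, query_block_size, max_span_plus_1)

# Pad the last dim from F_span to C + 1, fold W into it, slice off the overhang,
# then unfold back to [B, N, U, W, C].
padded = F.pad(term_bd, (0, (key_context_size + 1) - max_span_plus_1))
reshaped = padded.reshape(batch_size, num_heads, num_query_blocks, query_block_size * (key_context_size + 1))
sliced = reshaped[..., : query_block_size * key_context_size]
shifted = sliced.reshape(batch_size, num_heads, num_query_blocks, query_block_size, key_context_size)

print(shifted.shape)  # torch.Size([1, 1, 1, 2, 4])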