push current change

This commit is contained in:
Arthur 2025-06-26 17:19:49 +02:00
parent 9c8d3a70b8
commit 7478568172
4 changed files with 237 additions and 230 deletions

View File

@ -325,33 +325,24 @@ class Gemma3nAudioConfig(PretrainedConfig):
The epsilon used by the rms normalization layers.
gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
Clipping value used to stabilize extremely large gradient values.
conf_attention_chunk_size (`int`, *optional*, defaults to 12):
The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_left (`int`, *optional*, defaults to 13):
The left context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_right (`int`, *optional*, defaults to 0):
The right context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
Logit cap applied during local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_attention_heads (`int`, *optional*, defaults to 8):
The number of attention heads in local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_hidden_layers (`int`, *optional*, defaults to 12):
The number of layers that use local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_conv_kernel_size (`int`, *optional*, defaults to 5):
Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_reduction_factor (`int`, *optional*, defaults to 4):
Reduction factor used in the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_residual_weight (`float`, *optional*, defaults to 0.5):
Residual connection weight inside the Conformer ("conf") section of the
Universal Speech Model.
attention_chunk_size (`int`, *optional*, defaults to 12):
The sub-sequence size for local attention processing.
attention_context_left (`int`, *optional*, defaults to 13):
The left context size of the local attention.
attention_context_right (`int`, *optional*, defaults to 0):
The right context size of the local attention.
attention_logit_cap (`float`, *optional*, defaults to 50.0):
Logit cap applied during local attention.
num_attention_heads (`int`, *optional*, defaults to 8):
The number of attention heads in local attention.
num_hidden_layers (`int`, *optional*, defaults to 12):
The number of layers that use local attention.
conv_kernel_size (`int`, *optional*, defaults to 5):
Convolution kernel size for the conformer block.
reduction_factor (`int`, *optional*, defaults to 4):
Reduction factor used in the conformer block.
residual_weight (`float`, *optional*, defaults to 0.5):
Residual connection weight.
sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
("sscp") section of the Universal Speech Model.
@ -395,15 +386,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
hidden_size: int = 1536,
rms_norm_eps: float = 1e-6,
gradient_clipping: float = 10_000_000_000.0,
conf_attention_chunk_size: int = 12,
conf_attention_context_left: int = 13,
conf_attention_context_right: int = 0,
conf_attention_logit_cap: float = 50.0,
conf_num_attention_heads: int = 8,
conf_num_hidden_layers: int = 12,
conf_conv_kernel_size: int = 5,
conf_reduction_factor: int = 4,
conf_residual_weight: float = 0.5,
attention_chunk_size: int = 12,
attention_context_left: int = 13,
attention_context_right: int = 0,
attention_logit_cap: float = 50.0,
num_attention_heads: int = 8,
num_hidden_layers: int = 12,
conv_kernel_size: int = 5,
reduction_factor: int = 4,
residual_weight: float = 0.5,
sscp_conv_channel_size: tuple[int, int] = (128, 32),
sscp_conv_group_norm_eps: float = 1e-3,
sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
@ -423,15 +414,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.gradient_clipping = gradient_clipping
self.conf_attention_chunk_size = conf_attention_chunk_size
self.conf_attention_context_left = conf_attention_context_left
self.conf_attention_context_right = conf_attention_context_right
self.conf_attention_logit_cap = conf_attention_logit_cap
self.conf_num_attention_heads = conf_num_attention_heads
self.conf_num_hidden_layers = conf_num_hidden_layers
self.conf_conv_kernel_size = conf_conv_kernel_size
self.conf_reduction_factor = conf_reduction_factor
self.conf_residual_weight = conf_residual_weight
self.attention_chunk_size = attention_chunk_size
self.attention_context_left = attention_context_left
self.attention_context_right = attention_context_right
self.attention_logit_cap = attention_logit_cap
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.conv_kernel_size = conv_kernel_size
self.reduction_factor = reduction_factor
self.residual_weight = residual_weight
self.sscp_conv_channel_size = sscp_conv_channel_size
self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
self.sscp_conv_kernel_size = sscp_conv_kernel_size
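For reference, a minimal sketch of constructing the config with the renamed arguments (assuming Gemma3nAudioConfig is importable on this branch; values are just the defaults shown above):

from transformers import Gemma3nAudioConfig  # assumed import path for this branch

# Same defaults as before, only the conf_ prefix is gone.
audio_config = Gemma3nAudioConfig(
    attention_chunk_size=12,
    attention_context_left=13,
    attention_context_right=0,
    attention_logit_cap=50.0,
    num_attention_heads=8,
    num_hidden_layers=12,
    conv_kernel_size=5,
    reduction_factor=4,
    residual_weight=0.5,
)
assert audio_config.num_hidden_layers == 12  # was config.conf_num_hidden_layers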

View File

@ -240,7 +240,7 @@ def convert_audio_encoder_weights(
converted_weights: list[Any] = []
if path.startswith(_AUDIO_ENCODER_CONFORMER):
assert weights.shape[0] == config.conf_num_hidden_layers
assert weights.shape[0] == config.num_hidden_layers
for i, matrix in enumerate(weights):
if "fflayer_end" in path:

View File

@ -149,11 +149,11 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.num_heads = self.config.num_attention_heads
self.channels = self.config.hidden_size
self.head_dim = self.channels // self.num_heads
self.max_backward = max(0, self.config.conf_attention_context_left - 1)
self.max_forward = self.config.conf_attention_context_right
self.max_backward = max(0, self.config.attention_context_left - 1)
self.max_forward = self.config.attention_context_right
self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
@ -319,7 +319,7 @@ class Gemma3nAudioAttention(nn.Module):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.num_heads = self.config.num_attention_heads
self.hidden_size = self.config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
@ -826,7 +826,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
self.post_layer_scale = torch.tensor(self.config.residual_weight)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
residual = audio_encodings
@ -850,7 +850,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.depthwise_conv1d = nn.Conv1d(
in_channels=self.config.hidden_size,
out_channels=self.config.hidden_size,
kernel_size=self.config.conf_conv_kernel_size,
kernel_size=self.config.conv_kernel_size,
stride=1,
padding=0, # Manual causal padding
groups=self.config.hidden_size, # Depthwise
@ -860,7 +860,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
self.causal_padding = self.config.conf_conv_kernel_size - 1
self.causal_padding = self.config.conv_kernel_size - 1
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_residual = audio_encodings # Save for residual connection
@ -922,9 +922,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.config = config
self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)])
def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
@ -973,10 +971,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
for block in self.conformer:
audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
if self.config.conf_reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
if self.config.reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.reduction_factor]
# Reduce the mask as well
current_mask = current_mask[:, :: self.config.conf_reduction_factor]
current_mask = current_mask[:, :: self.config.reduction_factor]
audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
return audio_encodings, current_mask
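The reduction_factor block above strided-subsamples the time axis of both the encodings and the mask. A small standalone sketch of the shape arithmetic, with hypothetical sizes:

import torch

reduction_factor = 4
audio_encodings = torch.randn(2, 160, 1536)            # [batch, time, hidden], hypothetical
current_mask = torch.zeros(2, 160, dtype=torch.bool)   # True marks padded frames

if reduction_factor > 1:
    audio_encodings = audio_encodings[:, ::reduction_factor]  # keep every 4th frame -> [2, 40, 1536]
    current_mask = current_mask[:, ::reduction_factor]        # downsample the mask the same way

audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)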

View File

@ -21,6 +21,7 @@ from typing import Any, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.phi4_multimodal.modeling_phi4_multimodal import unfold_tensor
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, HybridCache
@ -407,15 +408,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
hidden_size: int = 1536,
rms_norm_eps: float = 1e-6,
gradient_clipping: float = 10_000_000_000.0,
conf_attention_chunk_size: int = 12,
conf_attention_context_left: int = 13,
conf_attention_context_right: int = 0,
conf_attention_logit_cap: float = 50.0,
conf_num_attention_heads: int = 8,
conf_num_hidden_layers: int = 12,
conf_conv_kernel_size: int = 5,
conf_reduction_factor: int = 4,
conf_residual_weight: float = 0.5,
attention_chunk_size: int = 12,
attention_context_left: int = 13,
attention_context_right: int = 0,
attention_logit_cap: float = 50.0,
num_attention_heads: int = 8,
num_hidden_layers: int = 12,
conv_kernel_size: int = 5,
reduction_factor: int = 4,
residual_weight: float = 0.5,
sscp_conv_channel_size: tuple[int, int] = (128, 32),
sscp_conv_group_norm_eps: float = 1e-3,
sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
@ -435,15 +436,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.gradient_clipping = gradient_clipping
self.conf_attention_chunk_size = conf_attention_chunk_size
self.conf_attention_context_left = conf_attention_context_left
self.conf_attention_context_right = conf_attention_context_right
self.conf_attention_logit_cap = conf_attention_logit_cap
self.conf_num_attention_heads = conf_num_attention_heads
self.conf_num_hidden_layers = conf_num_hidden_layers
self.conf_conv_kernel_size = conf_conv_kernel_size
self.conf_reduction_factor = conf_reduction_factor
self.conf_residual_weight = conf_residual_weight
self.attention_chunk_size = attention_chunk_size
self.attention_context_left = attention_context_left
self.attention_context_right = attention_context_right
self.attention_logit_cap = attention_logit_cap
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.conv_kernel_size = conv_kernel_size
self.reduction_factor = reduction_factor
self.residual_weight = residual_weight
self.sscp_conv_channel_size = sscp_conv_channel_size
self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
self.sscp_conv_kernel_size = sscp_conv_kernel_size
@ -706,38 +707,7 @@ class Gemma3nRMSNorm(Gemma3RMSNorm):
# ==== Audio Encoder ====
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.channels = self.config.hidden_size
self.head_dim = self.channels // self.num_heads
self.max_backward = max(0, self.config.conf_attention_context_left - 1)
self.max_forward = self.config.conf_attention_context_right
self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
min_timescale = 1.0
max_timescale = 1.0e4
num_timescales = self.channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
self.register_buffer(
"inv_timescales",
inv_timescales.float().unsqueeze(0).unsqueeze(0),
persistent=False,
)
def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
position = position.float().unsqueeze(-1)
scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
return timing_signal.type(dtype)
def _relative_shift(
self,
def _relative_shift(
term_bd_before_shift: torch.Tensor,
batch_size: int,
num_heads: int,
@ -745,7 +715,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
query_block_size: int,
key_context_size: int,
max_span_plus_1: int,
) -> torch.Tensor:
) -> torch.Tensor:
"""Performs the relative shift.
Args:
@ -797,57 +767,68 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
)
return term_bd_shifted
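The relative shift realigns scores indexed by relative offset (width F = L + R + 1) to absolute key positions inside the context window (width C = W + L + R). An illustrative pad-reshape-slice toy on a single [W, F] slice; the real method keeps the batch/head/block dims in front:

import torch
import torch.nn.functional as F

W, F_span, C = 2, 3, 4                                          # query block 2, span 3, context 4
scores = torch.arange(1.0, W * F_span + 1).reshape(W, F_span)   # [[1, 2, 3], [4, 5, 6]]

padded = F.pad(scores, (0, (C + 1) - F_span))                   # pad last dim to C + 1
shifted = padded.reshape(W * (C + 1))[: W * C].reshape(W, C)
# each row's relative scores now start at that query's absolute key position:
# [[1, 2, 3, 0],
#  [0, 4, 5, 6]]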
def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
# queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
# keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
# C = W + L + R (key_context_size)
# F_span = L + R + 1 (max_span + 1)
batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
_, _, key_context_size, _, _ = keys.shape
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.num_heads = self.config.num_attention_heads
self.channels = self.config.hidden_size
self.head_dim = self.channels // self.num_heads
self.max_backward = max(0, self.config.attention_context_left - 1)
self.max_forward = self.config.attention_context_right
self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
min_timescale = 1.0
max_timescale = 1.0e4
num_timescales = self.channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
self.register_buffer(
"inv_timescales",
inv_timescales.float().unsqueeze(0).unsqueeze(0),
persistent=False,
)
def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
position = position.float().unsqueeze(-1)
scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
return timing_signal.type(dtype)
def forward(self) -> torch.Tensor:
# Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
# Length is L+R+1 = self.max_span + 1
pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=self.pos_proj.weight.device).unsqueeze(
0
) # Shape [1, F_span]
max_span_plus_1 = pos_indices.shape[1] # F_span
sin_emb_timing_signal = self._get_timing_signal_1d_pos(
pos_indices, dtype=queries.dtype
) # Shape [1, F_span, self.channels]
sin_emb_timing_signal = self._get_timing_signal_1d_pos(pos_indices, dtype=self.pos_proj.weight.dtype)
# Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
# Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
0
) # Shape [F, N, H]
s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
return (s_permuted, max_span_plus_1)
# term_ac: Query-Key content interaction
# queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
# keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
def apply_audio_rotary_embedding(queries, keys, position_embeddings):
s_permuted, max_span_plus_1 = position_embeddings
##### SPLIT THIS
batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
_, _, key_context_size, _, _ = keys.shape
q_permuted = queries.permute(0, 3, 1, 2, 4)
q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
# term_bd: Query-Position interaction
# Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
# queries shape: [B, U, W, N, H]
# sin_emb shape: [F, N, H]
# Target output shape: [B, N, U, W, F]
# Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
q_permuted = queries.permute(0, 3, 1, 2, 4)
# Permute sin_emb to [N, H, F] to prepare for matmul
# sin_emb original is [F, N, H]
s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
# Reshape queries for matmul: [B, N, U*W, H]
q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
# Perform matmul: [B, N, U*W, H] @ [N, H, F]
# s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
# Result: [B, N, U*W, F]
@ -863,7 +844,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
)
# Apply relative shift to term_bd_unshifed
term_bd_shifted = self._relative_shift(
term_bd_shifted = _relative_shift(
term_bd_unshifed,
batch_size,
num_heads,
@ -876,12 +857,31 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
return term_ac + term_bd_shifted
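The permute/reshape/matmul route described in the comments above is meant to reproduce the original einsum. A quick equivalence check on random tensors (shapes are illustrative):

import torch

B, U, W, N, H, F_span = 2, 3, 4, 8, 16, 14
queries = torch.randn(B, U, W, N, H)
sin_emb = torch.randn(F_span, N, H)

reference = torch.einsum("buwnh,fnh->bnuwf", queries, sin_emb)

q_reshaped = queries.permute(0, 3, 1, 2, 4).reshape(B, N, U * W, H)   # [B, N, U*W, H]
s_permuted = sin_emb.permute(1, 2, 0)                                  # [N, H, F]
term_bd = torch.matmul(q_reshaped, s_permuted).reshape(B, N, U, W, F_span)

torch.testing.assert_close(term_bd, reference)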
def _frame_hidden_states(config, hidden_states: torch.Tensor) -> torch.Tensor:
pad_left = config.max_past_horizon
pad_right = config.max_future_horizon + config.chunk_size - 1
# Pad only the time dimension (dim=1)
padding = (0, 0) * (hidden_states.dim() - 2) + (pad_left, pad_right)  # zeros for trailing dims, (pad_left, pad_right) lands on dim=1
hidden_states = F.pad(hidden_states, padding)
frame_len = config.context_size
frame_step = config.chunk_size
# Unfold time dimension into overlapping frames
unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
# Move the context-window dim to sit right after the block dim (dim 2), for any input rank
unfolded = unfolded.movedim(-1, 2)
return unfolded.contiguous()
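A shape-only sketch of the framing this helper performs, inlined with the default-derived sizes (chunk 12, past horizon 12, future horizon 0); keys/values go from [B, T, N, H] to [B, num_blocks, context_size, N, H]:

import torch
import torch.nn.functional as F

chunk_size, max_past_horizon, max_future_horizon = 12, 12, 0
context_size = chunk_size + max_past_horizon + max_future_horizon       # 24

hidden_states = torch.randn(2, 50, 8, 64)                               # [B, T, N, H], hypothetical T
pad_left = max_past_horizon
pad_right = max_future_horizon + chunk_size - 1                         # rounds T up to whole chunks
padded = F.pad(hidden_states, (0, 0) * (hidden_states.dim() - 2) + (pad_left, pad_right))

frames = padded.unfold(dimension=1, size=context_size, step=chunk_size)  # [B, U, N, H, C]
frames = frames.movedim(-1, 2).contiguous()                              # [B, U, C, N, H]
print(frames.shape)  # torch.Size([2, 5, 24, 8, 64])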
class Gemma3nAudioAttention(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.num_heads = self.config.num_attention_heads
self.hidden_size = self.config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
@ -891,7 +891,6 @@ class Gemma3nAudioAttention(nn.Module):
self.attention_logits_soft_cap = self.config.attention_logit_cap
self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
@ -901,23 +900,45 @@ class Gemma3nAudioAttention(nn.Module):
q_scale = self.head_dim**-0.5
r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
lower_causal_mask = torch.tril(
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
diagonal=0,
).T
upper_causal_mask = torch.tril(
torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
diagonal=self.max_past_horizon + self.max_future_horizon,
)
local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
self.register_buffer(
"softcap",
torch.tensor(self.attention_logits_soft_cap).float(),
torch.tensor(self.config.attention_logit_cap).float(),
persistent=False,
) # TODO do we even need a tensor for that?
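The two triangular masks above intersect into a band: each query in a chunk may attend only to a contiguous run of max_past_horizon + max_future_horizon + 1 framed keys. With toy sizes (chunk 3, past horizon 2, future horizon 0) the band is easy to see:

import torch

chunk_size, max_past_horizon, max_future_horizon = 3, 2, 0
context_size = chunk_size + max_past_horizon + max_future_horizon       # 5

lower = torch.tril(torch.ones((context_size, chunk_size), dtype=torch.bool)).T
upper = torch.tril(
    torch.ones((chunk_size, context_size), dtype=torch.bool),
    diagonal=max_past_horizon + max_future_horizon,
)
local_causal_valid_mask = lower & upper
# rows (queries in the chunk) x columns (framed keys):
# [1 1 1 0 0]
# [0 1 1 1 0]
# [0 0 1 1 1]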
def forward(
self, hidden_states: torch.Tensor, position_embeddings, attention_mask: torch.BoolTensor, **kwargs
) -> torch.Tensor:
hidden_shape = hidden_states.shape[:-1]
query_states = self.q_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
key_states = self.k_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
value_states = self.v_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale).view(1, 1, 1, self.head_dim)
scaling = self.q_scale * per_dim_scale_sp
# shape is B, audio_length, num_heads, head_dim
num_blocks = (query_states.shape[1] + self.chunk_size - 1) // self.chunk_size
padding_length = num_blocks * self.chunk_size - query_states.shape[1]
query_states = torch.nn.functional.pad(query_states, (0, 0, 0, 0, 0, padding_length))
query_states = query_states.reshape(-1, num_blocks, self.chunk_size, self.num_heads, self.head_dim)
key_states = _frame_hidden_states(self, key_states)
value_states = _frame_hidden_states(self, value_states)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=scaling,
**kwargs,
)
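The block/padding arithmetic at the top of this forward rounds the audio length up to a whole number of chunks before the reshape. With hypothetical sizes:

import torch

chunk_size, num_heads, head_dim = 12, 8, 64
audio_length = 50                                              # hypothetical frame count
num_blocks = (audio_length + chunk_size - 1) // chunk_size     # ceil(50 / 12) = 5
padding_length = num_blocks * chunk_size - audio_length        # 60 - 50 = 10

query_states = torch.randn(2, audio_length, num_heads, head_dim)
query_states = torch.nn.functional.pad(query_states, (0, 0, 0, 0, 0, padding_length))
query_states = query_states.reshape(-1, num_blocks, chunk_size, num_heads, head_dim)
print(query_states.shape)  # torch.Size([2, 5, 12, 8, 64])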
def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
@ -1089,15 +1110,14 @@ class Gemma3nAudioAttention(nn.Module):
context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
context_vectors = context_vectors.reshape(
(
batch_size,
num_query_blocks * self.chunk_size,
hidden_shape[0],
query_states.shape[1] * self.chunk_size,
self.num_heads,
self.head_dim,
)
)
context_vectors = context_vectors[:, :q_time]
return context_vectors
context_vectors = context_vectors[:, : hidden_shape[1]]
return context_vectors, attn_weights
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
@ -1388,7 +1408,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
self.post_layer_scale = torch.tensor(self.config.residual_weight)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
residual = audio_encodings
@ -1412,7 +1432,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.depthwise_conv1d = nn.Conv1d(
in_channels=self.config.hidden_size,
out_channels=self.config.hidden_size,
kernel_size=self.config.conf_conv_kernel_size,
kernel_size=self.config.conv_kernel_size,
stride=1,
padding=0, # Manual causal padding
groups=self.config.hidden_size, # Depthwise
@ -1422,7 +1442,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
self.causal_padding = self.config.conf_conv_kernel_size - 1
self.causal_padding = self.config.conv_kernel_size - 1
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_residual = audio_encodings # Save for residual connection
@ -1484,9 +1504,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.config = config
self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)])
def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
@ -1535,10 +1553,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
for block in self.conformer:
audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
if self.config.conf_reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
if self.config.reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.reduction_factor]
# Reduce the mask as well
current_mask = current_mask[:, :: self.config.conf_reduction_factor]
current_mask = current_mask[:, :: self.config.reduction_factor]
audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
return audio_encodings, current_mask