Commit 7f478764c2
Arthur, 2025-07-02 21:52:09 +02:00, committed by GitHub
4 changed files with 237 additions and 230 deletions

View File

@ -325,33 +325,24 @@ class Gemma3nAudioConfig(PretrainedConfig):
The epsilon used by the rms normalization layers.
gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
Clipping value used to stabilize extremely large gradient values.
conf_attention_chunk_size (`int`, *optional*, defaults to 12):
The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_left (`int`, *optional*, defaults to 13):
The left context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_right (`int`, *optional*, defaults to 0):
The right context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
Logit cap applied during local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_attention_heads (`int`, *optional*, defaults to 8):
The number of attention heads in local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_hidden_layers (`int`, *optional*, defaults to 12):
The number of layers that use local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_conv_kernel_size (`int`, *optional*, defaults to 5):
Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_reduction_factor (`int`, *optional*, defaults to 4):
Reduction factor used in the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_residual_weight (`float`, *optional*, defaults to 0.5):
Residual connection weight inside the Conformer ("conf") section of the
Universal Speech Model.
attention_chunk_size (`int`, *optional*, defaults to 12):
The sub-sequence size for local attention processing.
attention_context_left (`int`, *optional*, defaults to 13):
The left context size of the local attention.
attention_context_right (`int`, *optional*, defaults to 0):
The right context size of the local attention.
attention_logit_cap (`float`, *optional*, defaults to 50.0):
Logit cap applied during local attention.
num_attention_heads (`int`, *optional*, defaults to 8):
The number of attention heads in local attention.
num_hidden_layers (`int`, *optional*, defaults to 12):
The number of layers that use local attention.
conv_kernel_size (`int`, *optional*, defaults to 5):
Convolution kernel size for the conformer block.
reduction_factor (`int`, *optional*, defaults to 4):
Reduction factor used in the conformer block.
residual_weight (`float`, *optional*, defaults to 0.5):
Residual connection weight.
sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
("sscp") section of the Universal Speech Model.
@ -395,15 +386,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
hidden_size: int = 1536,
rms_norm_eps: float = 1e-6,
gradient_clipping: float = 10_000_000_000.0,
conf_attention_chunk_size: int = 12,
conf_attention_context_left: int = 13,
conf_attention_context_right: int = 0,
conf_attention_logit_cap: float = 50.0,
conf_num_attention_heads: int = 8,
conf_num_hidden_layers: int = 12,
conf_conv_kernel_size: int = 5,
conf_reduction_factor: int = 4,
conf_residual_weight: float = 0.5,
attention_chunk_size: int = 12,
attention_context_left: int = 13,
attention_context_right: int = 0,
attention_logit_cap: float = 50.0,
num_attention_heads: int = 8,
num_hidden_layers: int = 12,
conv_kernel_size: int = 5,
reduction_factor: int = 4,
residual_weight: float = 0.5,
sscp_conv_channel_size: tuple[int, int] = (128, 32),
sscp_conv_group_norm_eps: float = 1e-3,
sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
@ -423,15 +414,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.gradient_clipping = gradient_clipping
self.conf_attention_chunk_size = conf_attention_chunk_size
self.conf_attention_context_left = conf_attention_context_left
self.conf_attention_context_right = conf_attention_context_right
self.conf_attention_logit_cap = conf_attention_logit_cap
self.conf_num_attention_heads = conf_num_attention_heads
self.conf_num_hidden_layers = conf_num_hidden_layers
self.conf_conv_kernel_size = conf_conv_kernel_size
self.conf_reduction_factor = conf_reduction_factor
self.conf_residual_weight = conf_residual_weight
self.attention_chunk_size = attention_chunk_size
self.attention_context_left = attention_context_left
self.attention_context_right = attention_context_right
self.attention_logit_cap = attention_logit_cap
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.conv_kernel_size = conv_kernel_size
self.reduction_factor = reduction_factor
self.residual_weight = residual_weight
self.sscp_conv_channel_size = sscp_conv_channel_size
self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
self.sscp_conv_kernel_size = sscp_conv_kernel_size
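Because the rename drops the `conf_` prefix, configs serialized with the old keys no longer match the new attribute names. A hypothetical helper (not part of this commit) that remaps legacy keys before instantiation could look like:

_LEGACY_PREFIX = "conf_"

def remap_legacy_audio_config_keys(config_dict: dict) -> dict:
    """Hypothetical helper: strip the legacy `conf_` prefix from audio config keys."""
    return {
        (key[len(_LEGACY_PREFIX):] if key.startswith(_LEGACY_PREFIX) else key): value
        for key, value in config_dict.items()
    }

# e.g. {"conf_num_attention_heads": 8} -> {"num_attention_heads": 8}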

View File

@ -240,7 +240,7 @@ def convert_audio_encoder_weights(
converted_weights: list[Any] = []
if path.startswith(_AUDIO_ENCODER_CONFORMER):
assert weights.shape[0] == config.conf_num_hidden_layers
assert weights.shape[0] == config.num_hidden_layers
for i, matrix in enumerate(weights):
if "fflayer_end" in path:

View File

@ -149,11 +149,11 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.num_heads = self.config.num_attention_heads
self.channels = self.config.hidden_size
self.head_dim = self.channels // self.num_heads
self.max_backward = max(0, self.config.conf_attention_context_left - 1)
self.max_forward = self.config.conf_attention_context_right
self.max_backward = max(0, self.config.attention_context_left - 1)
self.max_forward = self.config.attention_context_right
self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
@ -319,7 +319,7 @@ class Gemma3nAudioAttention(nn.Module):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.num_heads = self.config.num_attention_heads
self.hidden_size = self.config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
@ -826,7 +826,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
self.post_layer_scale = torch.tensor(self.config.residual_weight)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
residual = audio_encodings
@ -850,7 +850,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.depthwise_conv1d = nn.Conv1d(
in_channels=self.config.hidden_size,
out_channels=self.config.hidden_size,
kernel_size=self.config.conf_conv_kernel_size,
kernel_size=self.config.conv_kernel_size,
stride=1,
padding=0, # Manual causal padding
groups=self.config.hidden_size, # Depthwise
@ -860,7 +860,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
self.causal_padding = self.config.conf_conv_kernel_size - 1
self.causal_padding = self.config.conv_kernel_size - 1
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_residual = audio_encodings # Save for residual connection
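The `causal_padding` above pads only the left of the time axis, so the depthwise convolution never sees future frames. A standalone sketch of that arithmetic (hidden size and kernel size are the documented defaults, input length arbitrary):

import torch
import torch.nn.functional as F

hidden_size, kernel_size = 1536, 5
causal_padding = kernel_size - 1                 # 4 frames of left padding
x = torch.randn(2, hidden_size, 10)              # [batch, hidden, time]
conv = torch.nn.Conv1d(hidden_size, hidden_size, kernel_size, groups=hidden_size)  # depthwise
out = conv(F.pad(x, (causal_padding, 0)))        # pad left only
print(out.shape)                                 # torch.Size([2, 1536, 10]): length preserved, no future leakage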
@ -922,9 +922,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.config = config
self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)])
def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
@ -973,10 +971,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
for block in self.conformer:
audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
if self.config.conf_reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
if self.config.reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.reduction_factor]
# Reduce the mask as well
current_mask = current_mask[:, :: self.config.conf_reduction_factor]
current_mask = current_mask[:, :: self.config.reduction_factor]
audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
return audio_encodings, current_mask

View File

@ -21,6 +21,7 @@ from typing import Any, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.phi4_multimodal.modeling_phi4_multimodal import unfold_tensor
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, HybridCache
@ -407,15 +408,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
hidden_size: int = 1536,
rms_norm_eps: float = 1e-6,
gradient_clipping: float = 10_000_000_000.0,
conf_attention_chunk_size: int = 12,
conf_attention_context_left: int = 13,
conf_attention_context_right: int = 0,
conf_attention_logit_cap: float = 50.0,
conf_num_attention_heads: int = 8,
conf_num_hidden_layers: int = 12,
conf_conv_kernel_size: int = 5,
conf_reduction_factor: int = 4,
conf_residual_weight: float = 0.5,
attention_chunk_size: int = 12,
attention_context_left: int = 13,
attention_context_right: int = 0,
attention_logit_cap: float = 50.0,
num_attention_heads: int = 8,
num_hidden_layers: int = 12,
conv_kernel_size: int = 5,
reduction_factor: int = 4,
residual_weight: float = 0.5,
sscp_conv_channel_size: tuple[int, int] = (128, 32),
sscp_conv_group_norm_eps: float = 1e-3,
sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
@ -435,15 +436,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.gradient_clipping = gradient_clipping
self.conf_attention_chunk_size = conf_attention_chunk_size
self.conf_attention_context_left = conf_attention_context_left
self.conf_attention_context_right = conf_attention_context_right
self.conf_attention_logit_cap = conf_attention_logit_cap
self.conf_num_attention_heads = conf_num_attention_heads
self.conf_num_hidden_layers = conf_num_hidden_layers
self.conf_conv_kernel_size = conf_conv_kernel_size
self.conf_reduction_factor = conf_reduction_factor
self.conf_residual_weight = conf_residual_weight
self.attention_chunk_size = attention_chunk_size
self.attention_context_left = attention_context_left
self.attention_context_right = attention_context_right
self.attention_logit_cap = attention_logit_cap
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.conv_kernel_size = conv_kernel_size
self.reduction_factor = reduction_factor
self.residual_weight = residual_weight
self.sscp_conv_channel_size = sscp_conv_channel_size
self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
self.sscp_conv_kernel_size = sscp_conv_kernel_size
@ -706,16 +707,77 @@ class Gemma3nRMSNorm(Gemma3RMSNorm):
# ==== Audio Encoder ====
def _relative_shift(
term_bd_before_shift: torch.Tensor,
batch_size: int,
num_heads: int,
num_query_blocks: int,
query_block_size: int,
key_context_size: int,
max_span_plus_1: int,
) -> torch.Tensor:
"""Performs the relative shift.
Args:
term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
(B), num_heads (N), num_query_blocks (U), query_block_size (W),
key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
Returns:
Tensor of shape [B, N, U, W, C].
"""
# term_bd_before_shift shape: [B, N, U, W, F_span]
# Target shape after shift: [B, N, U, W, C]
# Padding amount for the last dimension (F_span) to become (C + 1)
# C = key_context_size
# F_span = max_span_plus_1
pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
# PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
# We only pad the last dimension on the right.
padding_tuple = (0, pad_amount_last_dim)
term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
# Shape after pad: [B, N, U, W, C+1]
# Reshape for slicing (emulating JAX's behavior)
# [B, N, U, W * (C+1)]
term_bd_reshaped = term_bd_padded.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size * (key_context_size + 1),
)
)
# Slice to effective [B, N, U, W * C]
term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
# Reshape back to [B, N, U, W, C]
term_bd_shifted = term_bd_sliced.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
)
)
return term_bd_shifted
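A quick shape sanity check for the now module-level `_relative_shift`, using small arbitrary sizes (it assumes the function above is in scope):

import torch

B, N, U, W = 1, 2, 3, 4                    # batch, heads, query blocks, block size
L, R = 3, 0                                # left/right context
C, F_span = W + L + R, L + R + 1           # key_context_size, max_span_plus_1

term_bd = torch.randn(B, N, U, W, F_span)
shifted = _relative_shift(term_bd, B, N, U, W, C, F_span)
print(shifted.shape)                       # torch.Size([1, 2, 3, 4, 7]), i.e. [B, N, U, W, C]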
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.num_heads = self.config.num_attention_heads
self.channels = self.config.hidden_size
self.head_dim = self.channels // self.num_heads
self.max_backward = max(0, self.config.conf_attention_context_left - 1)
self.max_forward = self.config.conf_attention_context_right
self.max_backward = max(0, self.config.attention_context_left - 1)
self.max_forward = self.config.attention_context_right
self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
@ -736,144 +798,82 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
return timing_signal.type(dtype)
def _relative_shift(
self,
term_bd_before_shift: torch.Tensor,
batch_size: int,
num_heads: int,
num_query_blocks: int,
query_block_size: int,
key_context_size: int,
max_span_plus_1: int,
) -> torch.Tensor:
"""Performs the relative shift.
Args:
term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
(B), num_heads (N), num_query_blocks (U), query_block_size (W),
key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
Returns:
Tensor of shape [B, N, U, W, C].
"""
# term_bd_before_shift shape: [B, N, U, W, F_span]
# Target shape after shift: [B, N, U, W, C]
# Padding amount for the last dimension (F_span) to become (C + 1)
# C = key_context_size
# F_span = max_span_plus_1
pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
# PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
# We only pad the last dimension on the right.
padding_tuple = (0, pad_amount_last_dim)
term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
# Shape after pad: [B, N, U, W, C+1]
# Reshape for slicing (emulating JAX's behavior)
# [B, N, U, W * (C+1)]
term_bd_reshaped = term_bd_padded.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size * (key_context_size + 1),
)
)
# Slice to effective [B, N, U, W * C]
term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
# Reshape back to [B, N, U, W, C]
term_bd_shifted = term_bd_sliced.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
)
)
return term_bd_shifted
def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
# queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
# keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
# C = W + L + R (key_context_size)
# F_span = L + R + 1 (max_span + 1)
batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
_, _, key_context_size, _, _ = keys.shape
def forward(self) -> torch.Tensor:
# Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
# Length is L+R+1 = self.max_span + 1
pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=self.pos_proj.weight.device).unsqueeze(
0
) # Shape [1, F_span]
max_span_plus_1 = pos_indices.shape[1] # F_span
sin_emb_timing_signal = self._get_timing_signal_1d_pos(
pos_indices, dtype=queries.dtype
) # Shape [1, F_span, self.channels]
sin_emb_timing_signal = self._get_timing_signal_1d_pos(pos_indices, dtype=self.pos_proj.weight.dtype)
# Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
# Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
0
) # Shape [F, N, H]
# term_ac: Query-Key content interaction
# queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
# keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
# term_bd: Query-Position interaction
# Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
# queries shape: [B, U, W, N, H]
# sin_emb shape: [F, N, H]
# Target output shape: [B, N, U, W, F]
# Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
q_permuted = queries.permute(0, 3, 1, 2, 4)
# Permute sin_emb to [N, H, F] to prepare for matmul
# sin_emb original is [F, N, H]
s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
return (s_permuted, max_span_plus_1)
# Reshape queries for matmul: [B, N, U*W, H]
q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
# Perform matmul: [B, N, U*W, H] @ [N, H, F]
# s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
# Result: [B, N, U*W, F]
term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
def apply_audio_rotary_embedding(queries, keys, position_embeddings):
s_permuted, max_span_plus_1 = position_embeddings
##### SPLIT THIS
batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
_, _, key_context_size, _, _ = keys.shape
# Reshape to target [B, N, U, W, F]
term_bd_unshifed = term_bd_unshifed_matmul.reshape(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
max_span_plus_1,
)
q_permuted = queries.permute(0, 3, 1, 2, 4)
q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
# Apply relative shift to term_bd_unshifed
term_bd_shifted = self._relative_shift(
term_bd_unshifed,
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
max_span_plus_1,
) # Shape [B, N, U, W, C]
# Perform matmul: [B, N, U*W, H] @ [N, H, F]
# s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
# Result: [B, N, U*W, F]
term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
return term_ac + term_bd_shifted
# Reshape to target [B, N, U, W, F]
term_bd_unshifed = term_bd_unshifed_matmul.reshape(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
max_span_plus_1,
)
# Apply relative shift to term_bd_unshifed
term_bd_shifted = _relative_shift(
term_bd_unshifed,
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
max_span_plus_1,
) # Shape [B, N, U, W, C]
return term_ac + term_bd_shifted
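And a matching shape check for `apply_audio_rotary_embedding`, where `position_embeddings` is the `(s_permuted, max_span_plus_1)` pair produced by the embedding module's `forward` (small arbitrary sizes; both helpers above assumed in scope):

import torch

B, U, W, N, H = 1, 3, 4, 2, 8
L, R = 3, 0
C, F_span = W + L + R, L + R + 1

queries = torch.randn(B, U, W, N, H)
keys = torch.randn(B, U, C, N, H)
s_permuted = torch.randn(N, H, F_span)
scores = apply_audio_rotary_embedding(queries, keys, (s_permuted, F_span))
print(scores.shape)                        # torch.Size([1, 2, 3, 4, 7]), i.e. [B, N, U, W, C]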
def _frame_hidden_states(config, hidden_states: torch.Tensor) -> torch.Tensor:
pad_left = config.max_past_horizon
pad_right = config.max_future_horizon + config.chunk_size - 1
# Pad only the time dimension (dim=1)
padding = (0, 0, pad_left, pad_right) # (dim2_left, dim2_right, dim1_left, dim1_right)
hidden_states = F.pad(hidden_states, padding)
frame_len = config.context_size
frame_step = config.chunk_size
# Unfold time dimension into overlapping frames
unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
# Move the last dim (context window) to slot after batch
unfolded = unfolded.permute(0, 1, 3, 2)
return unfolded.contiguous()
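`_frame_hidden_states` pads the time axis and unfolds it into overlapping windows of `context_size`, stepped by `chunk_size`. A standalone sketch with a stand-in config object carrying just the attributes the helper reads (values arbitrary):

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(chunk_size=4, max_past_horizon=3, max_future_horizon=0, context_size=7)
hidden = torch.randn(1, 8, 16)             # [batch, time, channels]
framed = _frame_hidden_states(cfg, hidden)
print(framed.shape)                        # torch.Size([1, 2, 7, 16]): [batch, num_blocks, context_size, channels]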
class Gemma3nAudioAttention(nn.Module):
@ -881,7 +881,7 @@ class Gemma3nAudioAttention(nn.Module):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.num_heads = self.config.num_attention_heads
self.hidden_size = self.config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
@ -891,7 +891,6 @@ class Gemma3nAudioAttention(nn.Module):
self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
@ -901,23 +900,45 @@ class Gemma3nAudioAttention(nn.Module):
q_scale = self.head_dim**-0.5
r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
lower_causal_mask = torch.tril(
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
diagonal=0,
).T
upper_causal_mask = torch.tril(
torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
diagonal=self.max_past_horizon + self.max_future_horizon,
)
local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
self.register_buffer(
"softcap",
torch.tensor(self.attention_logits_soft_cap).float(),
torch.tensor(self.config.attention_logit_cap).float(),
persistent=False,
) # TODO do we even need a tensor for that?
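The two triangular masks combine into a banded validity pattern: each query position inside a chunk may attend to itself, up to `max_past_horizon` frames back, and up to `max_future_horizon` frames ahead within the context window. A tiny standalone example with small sizes, printing the pattern:

import torch

chunk_size, max_past, max_future = 3, 2, 0
context_size = chunk_size + max_past + max_future

lower = torch.tril(torch.ones((context_size, chunk_size), dtype=torch.bool), diagonal=0).T
upper = torch.tril(torch.ones((chunk_size, context_size), dtype=torch.bool), diagonal=max_past + max_future)
print((lower & upper).int())
# tensor([[1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)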
def forward(
self, hidden_states: torch.Tensor, position_embeddings, attention_mask: torch.BoolTensor, **kwargs
) -> torch.Tensor:
hidden_shape = hidden_states.shape[:-1]
query_states = self.q_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
key_states = self.k_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
value_states = self.v_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale).view(1, 1, 1, self.head_dim)
scaling = self.q_scale * per_dim_scale_sp
# shape is B, audio_length, num_heads, head_dim
num_blocks = (query_states.shape[1] + self.chunk_size - 1) // self.chunk_size
padding_length = num_blocks * self.chunk_size - query_states.shape[1]
query_states = torch.nn.functional.pad(query_states, (0, 0, 0, padding_length))
query_states = query_states.reshape(-1, num_blocks, self.chunk_size, self.num_heads, self.head_dim)
key_states = _frame_hidden_states(self.config, key_states)
value_states = _frame_hidden_states(self.config, value_states)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=scaling,
**kwargs,
)
def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
@ -1089,15 +1110,14 @@ class Gemma3nAudioAttention(nn.Module):
context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
context_vectors = context_vectors.reshape(
(
batch_size,
num_query_blocks * self.chunk_size,
hidden_shape[0],
query_states.shape[1] * self.chunk_size,
self.num_heads,
self.head_dim,
)
)
context_vectors = context_vectors[:, :q_time]
return context_vectors
context_vectors = context_vectors[:, : hidden_shape[1]]
return context_vectors, attn_weights
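The block and padding arithmetic in the forward above is plain ceiling division over the time axis; a short sketch with the documented default chunk size and an assumed sequence length:

import math

chunk_size = 12                            # default attention_chunk_size
audio_length = 50                          # assumed input length
num_blocks = math.ceil(audio_length / chunk_size)         # 5, same as (length + chunk - 1) // chunk
padding_length = num_blocks * chunk_size - audio_length   # 10 frames of right padding
print(num_blocks, padding_length)          # 5 10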
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
@ -1388,7 +1408,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
self.post_layer_scale = torch.tensor(self.config.residual_weight)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
residual = audio_encodings
@ -1412,7 +1432,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.depthwise_conv1d = nn.Conv1d(
in_channels=self.config.hidden_size,
out_channels=self.config.hidden_size,
kernel_size=self.config.conf_conv_kernel_size,
kernel_size=self.config.conv_kernel_size,
stride=1,
padding=0, # Manual causal padding
groups=self.config.hidden_size, # Depthwise
@ -1422,7 +1442,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
self.causal_padding = self.config.conf_conv_kernel_size - 1
self.causal_padding = self.config.conv_kernel_size - 1
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_residual = audio_encodings # Save for residual connection
@ -1484,9 +1504,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
self.config = config
self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)])
def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
@ -1535,10 +1553,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
for block in self.conformer:
audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
if self.config.conf_reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
if self.config.reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.reduction_factor]
# Reduce the mask as well
current_mask = current_mask[:, :: self.config.conf_reduction_factor]
current_mask = current_mask[:, :: self.config.reduction_factor]
audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
return audio_encodings, current_mask
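The stride slicing at the end of the encoder downsamples both the encodings and the mask along time by `reduction_factor`; a minimal sketch with the documented default of 4:

import torch

reduction_factor = 4
audio_encodings = torch.randn(2, 32, 1536)                 # [batch, time, hidden]
current_mask = torch.zeros(2, 32, dtype=torch.bool)

if reduction_factor > 1:
    audio_encodings = audio_encodings[:, ::reduction_factor]   # -> [2, 8, 1536]
    current_mask = current_mask[:, ::reduction_factor]         # -> [2, 8]

audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
print(audio_encodings.shape, current_mask.shape)           # torch.Size([2, 8, 1536]) torch.Size([2, 8])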