Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)

Merge 7478568172 into 2d561713f8
Commit 7f478764c2
@@ -325,33 +325,24 @@ class Gemma3nAudioConfig(PretrainedConfig):
             The epsilon used by the rms normalization layers.
         gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
             Clipping value used to stablize extremely large gradient values.
-        conf_attention_chunk_size (`int`, *optional*, defaults to 12):
-            The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_attention_context_left (`int`, *optional*, defaults to 13):
-            The left context size of the local attention inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_attention_context_right (`int`, *optional*, defaults to 0):
-            The right context size of the local attention inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
-            Logit cap applied during local attention inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_num_attention_heads (`int`, *optional*, defaults to 8):
-            The number of attention heads in local attention inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_num_hidden_layers (`int`, *optional*, defaults to 12):
-            The number of layers that use local attention inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_conv_kernel_size (`int`, *optional*, defaults to 5):
-            Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_reduction_factor (`int`, *optional*, defaults to 4):
-            Reduction factor used in the conformer block inside the Conformer ("conf") section of the
-            Universal Speech Model.
-        conf_residual_weight (`float`, *optional*, defaults to 0.5):
-            Residual connection weight inside the Conformer ("conf") section of the
-            Universal Speech Model.
+        attention_chunk_size (`int`, *optional*, defaults to 12):
+            The sub-sequence size for local attention processing.
+        attention_context_left (`int`, *optional*, defaults to 13):
+            The left context size of the local attention.
+        attention_context_right (`int`, *optional*, defaults to 0):
+            The right context size of the local attention.
+        attention_logit_cap (`float`, *optional*, defaults to 50.0):
+            Logit cap applied during local attention.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            The number of attention heads in local attention.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            The number of layers that use local attention.
+        conv_kernel_size (`int`, *optional*, defaults to 5):
+            Convolution kernel size for the conformer block.
+        reduction_factor (`int`, *optional*, defaults to 4):
+            Reduction factor used in the conformer block.
+        residual_weight (`float`, *optional*, defaults to 0.5):
+            Residual connection weight.
         sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
             The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
             ("sscp") section of the Universal Speech Model.
@@ -395,15 +386,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
         hidden_size: int = 1536,
         rms_norm_eps: float = 1e-6,
         gradient_clipping: float = 10_000_000_000.0,
-        conf_attention_chunk_size: int = 12,
-        conf_attention_context_left: int = 13,
-        conf_attention_context_right: int = 0,
-        conf_attention_logit_cap: float = 50.0,
-        conf_num_attention_heads: int = 8,
-        conf_num_hidden_layers: int = 12,
-        conf_conv_kernel_size: int = 5,
-        conf_reduction_factor: int = 4,
-        conf_residual_weight: float = 0.5,
+        attention_chunk_size: int = 12,
+        attention_context_left: int = 13,
+        attention_context_right: int = 0,
+        attention_logit_cap: float = 50.0,
+        num_attention_heads: int = 8,
+        num_hidden_layers: int = 12,
+        conv_kernel_size: int = 5,
+        reduction_factor: int = 4,
+        residual_weight: float = 0.5,
         sscp_conv_channel_size: tuple[int, int] = (128, 32),
         sscp_conv_group_norm_eps: float = 1e-3,
         sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
@@ -423,15 +414,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.vocab_offset = vocab_offset
         self.gradient_clipping = gradient_clipping
-        self.conf_attention_chunk_size = conf_attention_chunk_size
-        self.conf_attention_context_left = conf_attention_context_left
-        self.conf_attention_context_right = conf_attention_context_right
-        self.conf_attention_logit_cap = conf_attention_logit_cap
-        self.conf_num_attention_heads = conf_num_attention_heads
-        self.conf_num_hidden_layers = conf_num_hidden_layers
-        self.conf_conv_kernel_size = conf_conv_kernel_size
-        self.conf_reduction_factor = conf_reduction_factor
-        self.conf_residual_weight = conf_residual_weight
+        self.attention_chunk_size = attention_chunk_size
+        self.attention_context_left = attention_context_left
+        self.attention_context_right = attention_context_right
+        self.attention_logit_cap = attention_logit_cap
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.conv_kernel_size = conv_kernel_size
+        self.reduction_factor = reduction_factor
+        self.residual_weight = residual_weight
        self.sscp_conv_channel_size = sscp_conv_channel_size
        self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
        self.sscp_conv_kernel_size = sscp_conv_kernel_size
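For readers tracking the rename, a minimal usage sketch of the new un-prefixed fields (assuming `Gemma3nAudioConfig` is importable from this branch of transformers; the values are just the defaults listed above, and the snippet is illustrative, not part of the commit):

# Hypothetical usage of the renamed config fields; not part of this diff.
from transformers import Gemma3nAudioConfig

audio_config = Gemma3nAudioConfig(
    attention_chunk_size=12,      # was conf_attention_chunk_size
    attention_context_left=13,    # was conf_attention_context_left
    attention_context_right=0,    # was conf_attention_context_right
    attention_logit_cap=50.0,     # was conf_attention_logit_cap
    num_attention_heads=8,        # was conf_num_attention_heads
    num_hidden_layers=12,         # was conf_num_hidden_layers
    conv_kernel_size=5,           # was conf_conv_kernel_size
    reduction_factor=4,           # was conf_reduction_factor
    residual_weight=0.5,          # was conf_residual_weight
)
print(audio_config.num_hidden_layers)  # 12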
@@ -240,7 +240,7 @@ def convert_audio_encoder_weights(
     converted_weights: list[Any] = []

     if path.startswith(_AUDIO_ENCODER_CONFORMER):
-        assert weights.shape[0] == config.conf_num_hidden_layers
+        assert weights.shape[0] == config.num_hidden_layers

         for i, matrix in enumerate(weights):
             if "fflayer_end" in path:
@@ -149,11 +149,11 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
         super().__init__()
         self.config = config

-        self.num_heads = self.config.conf_num_attention_heads
+        self.num_heads = self.config.num_attention_heads
         self.channels = self.config.hidden_size
         self.head_dim = self.channels // self.num_heads
-        self.max_backward = max(0, self.config.conf_attention_context_left - 1)
-        self.max_forward = self.config.conf_attention_context_right
+        self.max_backward = max(0, self.config.attention_context_left - 1)
+        self.max_forward = self.config.attention_context_right

         self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
@@ -319,7 +319,7 @@ class Gemma3nAudioAttention(nn.Module):
         super().__init__()
         self.config = config

-        self.num_heads = self.config.conf_num_attention_heads
+        self.num_heads = self.config.num_attention_heads
         self.hidden_size = self.config.hidden_size
         self.head_dim = self.hidden_size // self.num_heads
@@ -826,7 +826,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
         self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
         self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
         self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
-        self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+        self.post_layer_scale = torch.tensor(self.config.residual_weight)

     def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
         residual = audio_encodings
@@ -850,7 +850,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
         self.depthwise_conv1d = nn.Conv1d(
             in_channels=self.config.hidden_size,
             out_channels=self.config.hidden_size,
-            kernel_size=self.config.conf_conv_kernel_size,
+            kernel_size=self.config.conv_kernel_size,
             stride=1,
             padding=0,  # Manual causal padding
             groups=self.config.hidden_size,  # Depthwise
@@ -860,7 +860,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
         self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
         self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)

-        self.causal_padding = self.config.conf_conv_kernel_size - 1
+        self.causal_padding = self.config.conv_kernel_size - 1

     def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
         audio_encodings_residual = audio_encodings  # Save for residual connection
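The manual causal padding used by this block can be illustrated in isolation. A standalone sketch, with assumed toy sizes rather than the real module, of why left-padding the time axis by kernel_size - 1 keeps a depthwise Conv1d causal:

# Standalone sketch of the manual causal padding pattern above; hidden_size=4
# and kernel_size=5 are assumed toy values, not a real Gemma3n configuration.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, kernel_size = 4, 5
causal_padding = kernel_size - 1  # mirrors self.causal_padding above

conv = nn.Conv1d(hidden_size, hidden_size, kernel_size, stride=1, padding=0, groups=hidden_size)

x = torch.randn(2, 50, hidden_size)       # [batch, time, channels]
x = x.permute(0, 2, 1)                    # Conv1d expects [batch, channels, time]
x = F.pad(x, (causal_padding, 0))         # pad only the left of the time axis
y = conv(x).permute(0, 2, 1)              # back to [batch, time, channels]
print(y.shape)  # torch.Size([2, 50, 4]) -- length preserved, no future frames used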
@@ -922,9 +922,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         self.config = config

         self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
-        self.conformer = nn.ModuleList(
-            [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
-        )
+        self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)])

     def forward(
         self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
@@ -973,10 +971,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         for block in self.conformer:
             audio_encodings = block(audio_encodings, current_mask)  # Pass the processed mask

-        if self.config.conf_reduction_factor > 1:
-            audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+        if self.config.reduction_factor > 1:
+            audio_encodings = audio_encodings[:, :: self.config.reduction_factor]
             # Reduce the mask as well
-            current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+            current_mask = current_mask[:, :: self.config.reduction_factor]

         audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
         return audio_encodings, current_mask
@@ -21,6 +21,7 @@ from typing import Any, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.models.phi4_multimodal.modeling_phi4_multimodal import unfold_tensor

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, HybridCache
@@ -407,15 +408,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
         hidden_size: int = 1536,
         rms_norm_eps: float = 1e-6,
         gradient_clipping: float = 10_000_000_000.0,
-        conf_attention_chunk_size: int = 12,
-        conf_attention_context_left: int = 13,
-        conf_attention_context_right: int = 0,
-        conf_attention_logit_cap: float = 50.0,
-        conf_num_attention_heads: int = 8,
-        conf_num_hidden_layers: int = 12,
-        conf_conv_kernel_size: int = 5,
-        conf_reduction_factor: int = 4,
-        conf_residual_weight: float = 0.5,
+        attention_chunk_size: int = 12,
+        attention_context_left: int = 13,
+        attention_context_right: int = 0,
+        attention_logit_cap: float = 50.0,
+        num_attention_heads: int = 8,
+        num_hidden_layers: int = 12,
+        conv_kernel_size: int = 5,
+        reduction_factor: int = 4,
+        residual_weight: float = 0.5,
         sscp_conv_channel_size: tuple[int, int] = (128, 32),
         sscp_conv_group_norm_eps: float = 1e-3,
         sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
@@ -435,15 +436,15 @@ class Gemma3nAudioConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.vocab_offset = vocab_offset
         self.gradient_clipping = gradient_clipping
-        self.conf_attention_chunk_size = conf_attention_chunk_size
-        self.conf_attention_context_left = conf_attention_context_left
-        self.conf_attention_context_right = conf_attention_context_right
-        self.conf_attention_logit_cap = conf_attention_logit_cap
-        self.conf_num_attention_heads = conf_num_attention_heads
-        self.conf_num_hidden_layers = conf_num_hidden_layers
-        self.conf_conv_kernel_size = conf_conv_kernel_size
-        self.conf_reduction_factor = conf_reduction_factor
-        self.conf_residual_weight = conf_residual_weight
+        self.attention_chunk_size = attention_chunk_size
+        self.attention_context_left = attention_context_left
+        self.attention_context_right = attention_context_right
+        self.attention_logit_cap = attention_logit_cap
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.conv_kernel_size = conv_kernel_size
+        self.reduction_factor = reduction_factor
+        self.residual_weight = residual_weight
        self.sscp_conv_channel_size = sscp_conv_channel_size
        self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
        self.sscp_conv_kernel_size = sscp_conv_kernel_size
@@ -706,16 +707,77 @@ class Gemma3nRMSNorm(Gemma3RMSNorm):
 # ==== Audio Encoder ====


+def _relative_shift(
+    term_bd_before_shift: torch.Tensor,
+    batch_size: int,
+    num_heads: int,
+    num_query_blocks: int,
+    query_block_size: int,
+    key_context_size: int,
+    max_span_plus_1: int,
+) -> torch.Tensor:
+    """Performs the relative shift.
+
+    Args:
+        term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
+            (B), num_heads (N), num_query_blocks (U), query_block_size (W),
+            key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
+
+    Returns:
+        Tensor of shape [B, N, U, W, C].
+    """
+    # term_bd_before_shift shape: [B, N, U, W, F_span]
+    # Target shape after shift: [B, N, U, W, C]
+
+    # Padding amount for the last dimension (F_span) to become (C + 1)
+    # C = key_context_size
+    # F_span = max_span_plus_1
+    pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
+
+    # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
+    # We only pad the last dimension on the right.
+    padding_tuple = (0, pad_amount_last_dim)
+
+    term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
+    # Shape after pad: [B, N, U, W, C+1]
+
+    # Reshape for slicing (emulating JAX's behavior)
+    # [B, N, U, W * (C+1)]
+    term_bd_reshaped = term_bd_padded.reshape(
+        (
+            batch_size,
+            num_heads,
+            num_query_blocks,
+            query_block_size * (key_context_size + 1),
+        )
+    )
+
+    # Slice to effective [B, N, U, W * C]
+    term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
+
+    # Reshape back to [B, N, U, W, C]
+    term_bd_shifted = term_bd_sliced.reshape(
+        (
+            batch_size,
+            num_heads,
+            num_query_blocks,
+            query_block_size,
+            key_context_size,
+        )
+    )
+    return term_bd_shifted
+
+
 class Gemma3nAudioRelativePositionEmbedding(nn.Module):
     def __init__(self, config: Gemma3nAudioConfig):
         super().__init__()
         self.config = config

-        self.num_heads = self.config.conf_num_attention_heads
+        self.num_heads = self.config.num_attention_heads
         self.channels = self.config.hidden_size
         self.head_dim = self.channels // self.num_heads
-        self.max_backward = max(0, self.config.conf_attention_context_left - 1)
-        self.max_forward = self.config.conf_attention_context_right
+        self.max_backward = max(0, self.config.attention_context_left - 1)
+        self.max_forward = self.config.attention_context_right

         self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
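The pad/reshape/slice trick in the new module-level `_relative_shift` is easiest to see on a tiny tensor. A standalone sketch with made-up sizes (B=1, N=2, U=3, W=4, C=6, F_span=3), checking only the output shape:

# Toy-sized check of the pad/reshape/slice relative-shift trick shown above.
# Sizes are illustrative only, not the Gemma3n defaults.
import torch
import torch.nn.functional as F

B, N, U, W, C, F_span = 1, 2, 3, 4, 6, 3
term_bd = torch.randn(B, N, U, W, F_span)

padded = F.pad(term_bd, (0, (C + 1) - F_span))   # [B, N, U, W, C+1]
flat = padded.reshape(B, N, U, W * (C + 1))      # merge the last two dims
sliced = flat[:, :, :, : W * C]                  # drop the overflow introduced by padding
shifted = sliced.reshape(B, N, U, W, C)          # [B, N, U, W, C]
print(shifted.shape)  # torch.Size([1, 2, 3, 4, 6])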
@@ -736,144 +798,82 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
         timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
         return timing_signal.type(dtype)

-    def _relative_shift(
-        self,
-        term_bd_before_shift: torch.Tensor,
-        batch_size: int,
-        num_heads: int,
-        num_query_blocks: int,
-        query_block_size: int,
-        key_context_size: int,
-        max_span_plus_1: int,
-    ) -> torch.Tensor:
-        """Performs the relative shift.
-
-        Args:
-            term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
-                (B), num_heads (N), num_query_blocks (U), query_block_size (W),
-                key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
-
-        Returns:
-            Tensor of shape [B, N, U, W, C].
-        """
-        # term_bd_before_shift shape: [B, N, U, W, F_span]
-        # Target shape after shift: [B, N, U, W, C]
-
-        # Padding amount for the last dimension (F_span) to become (C + 1)
-        # C = key_context_size
-        # F_span = max_span_plus_1
-        pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
-
-        # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
-        # We only pad the last dimension on the right.
-        padding_tuple = (0, pad_amount_last_dim)
-
-        term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
-        # Shape after pad: [B, N, U, W, C+1]
-
-        # Reshape for slicing (emulating JAX's behavior)
-        # [B, N, U, W * (C+1)]
-        term_bd_reshaped = term_bd_padded.reshape(
-            (
-                batch_size,
-                num_heads,
-                num_query_blocks,
-                query_block_size * (key_context_size + 1),
-            )
-        )
-
-        # Slice to effective [B, N, U, W * C]
-        term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
-
-        # Reshape back to [B, N, U, W, C]
-        term_bd_shifted = term_bd_sliced.reshape(
-            (
-                batch_size,
-                num_heads,
-                num_query_blocks,
-                query_block_size,
-                key_context_size,
-            )
-        )
-        return term_bd_shifted
-
-    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
-        # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
-        # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
-        # C = W + L + R (key_context_size)
-        # F_span = L + R + 1 (max_span + 1)
-
-        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
-        _, _, key_context_size, _, _ = keys.shape
-
+    def forward(self) -> torch.Tensor:
         # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
         # Length is L+R+1 = self.max_span + 1
-        pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
+        pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=self.device).unsqueeze(
             0
         )  # Shape [1, F_span]

         max_span_plus_1 = pos_indices.shape[1]  # F_span

-        sin_emb_timing_signal = self._get_timing_signal_1d_pos(
-            pos_indices, dtype=queries.dtype
-        )  # Shape [1, F_span, self.channels]
-
+        sin_emb_timing_signal = self._get_timing_signal_1d_pos(pos_indices, dtype=self.dtype)
         # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
         projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
         # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
         sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
             0
         )  # Shape [F, N, H]

-        # term_ac: Query-Key content interaction
-        # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
-        # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
-        queries_p = queries.permute(0, 3, 1, 2, 4)  # [B, N, U, W, H]
-        keys_p_t = keys.permute(0, 3, 1, 4, 2)  # [B, N, U, H, C]
-        term_ac = torch.matmul(queries_p, keys_p_t)  # [B, N, U, W, C]
-
-        # term_bd: Query-Position interaction
-        # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
-        # queries shape: [B, U, W, N, H]
-        # sin_emb shape: [F, N, H]
-        # Target output shape: [B, N, U, W, F]
-
-        # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
-        q_permuted = queries.permute(0, 3, 1, 2, 4)
-
         # Permute sin_emb to [N, H, F] to prepare for matmul
         # sin_emb original is [F, N, H]
         s_permuted = sin_emb.permute(1, 2, 0)  # Shape: [N, H, F]
+        return (s_permuted, max_span_plus_1)

-        # Reshape queries for matmul: [B, N, U*W, H]
-        q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
-
-        # Perform matmul: [B, N, U*W, H] @ [N, H, F]
-        # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
-        # Result: [B, N, U*W, F]
-        term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
+def apply_audio_rotary_embedding(queries, keys, postion_embeddings):
+    s_permuted, max_span_plus_1 = postion_embeddings
+    ##### SPLIT THIS
+    batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
+    _, _, key_context_size, _, _ = keys.shape

-        # Reshape to target [B, N, U, W, F]
-        term_bd_unshifed = term_bd_unshifed_matmul.reshape(
-            batch_size,
-            num_heads,
-            num_query_blocks,
-            query_block_size,
-            max_span_plus_1,
-        )
+    q_permuted = queries.permute(0, 3, 1, 2, 4)
+    q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
+    queries_p = queries.permute(0, 3, 1, 2, 4)  # [B, N, U, W, H]
+    keys_p_t = keys.permute(0, 3, 1, 4, 2)  # [B, N, U, H, C]
+    term_ac = torch.matmul(queries_p, keys_p_t)  # [B, N, U, W, C]

-        # Apply relative shift to term_bd_unshifed
-        term_bd_shifted = self._relative_shift(
-            term_bd_unshifed,
-            batch_size,
-            num_heads,
-            num_query_blocks,
-            query_block_size,
-            key_context_size,
-            max_span_plus_1,
-        )  # Shape [B, N, U, W, C]
+    # Perform matmul: [B, N, U*W, H] @ [N, H, F]
+    # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
+    # Result: [B, N, U*W, F]
+    term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)

-        return term_ac + term_bd_shifted
+    # Reshape to target [B, N, U, W, F]
+    term_bd_unshifed = term_bd_unshifed_matmul.reshape(
+        batch_size,
+        num_heads,
+        num_query_blocks,
+        query_block_size,
+        max_span_plus_1,
+    )

+    # Apply relative shift to term_bd_unshifed
+    term_bd_shifted = _relative_shift(
+        term_bd_unshifed,
+        batch_size,
+        num_heads,
+        num_query_blocks,
+        query_block_size,
+        key_context_size,
+        max_span_plus_1,
+    )  # Shape [B, N, U, W, C]

+    return term_ac + term_bd_shifted


+def _frame_hidden_states(config, hidden_states: torch.Tensor) -> torch.Tensor:
+    pad_left = config.max_past_horizon
+    pad_right = config.max_future_horizon + config.chunk_size - 1
+
+    # Pad only the time dimension (dim=1)
+    padding = (0, 0, pad_left, pad_right)  # (dim2_left, dim2_right, dim1_left, dim1_right)
+    hidden_states = F.pad(hidden_states, padding)
+
+    frame_len = config.context_size
+    frame_step = config.chunk_size
+
+    # Unfold time dimension into overlapping frames
+    unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
+
+    # Move the last dim (context window) to slot after batch
+    unfolded = unfolded.permute(0, 1, 3, 2)
+    return unfolded.contiguous()


 class Gemma3nAudioAttention(nn.Module):
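The `_frame_hidden_states` helper added above builds overlapping key/value frames with `Tensor.unfold`. A toy-sized sketch of the same pad-and-unfold pattern (chunk_size=3, max_past_horizon=2, max_future_horizon=1 are assumed values, not the Gemma3n defaults):

# Toy-sized sketch of the pad-and-unfold framing used in _frame_hidden_states above.
import torch
import torch.nn.functional as F

chunk_size, past, future = 3, 2, 1
context_size = chunk_size + past + future           # length of each frame

hidden_states = torch.randn(1, 9, 4)                 # [batch, time, channels]
hidden_states = F.pad(hidden_states, (0, 0, past, future + chunk_size - 1))  # pad time dim only

frames = hidden_states.unfold(dimension=1, size=context_size, step=chunk_size)
frames = frames.permute(0, 1, 3, 2).contiguous()     # [batch, num_blocks, context_size, channels]
print(frames.shape)  # torch.Size([1, 3, 6, 4]) -- one context window per chunk of 3 frames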
@@ -881,7 +881,7 @@ class Gemma3nAudioAttention(nn.Module):
         super().__init__()
         self.config = config

-        self.num_heads = self.config.conf_num_attention_heads
+        self.num_heads = self.config.num_attention_heads
         self.hidden_size = self.config.hidden_size
         self.head_dim = self.hidden_size // self.num_heads
@@ -891,7 +891,6 @@ class Gemma3nAudioAttention(nn.Module):
-        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
         self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon

         self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
         self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))

         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
@@ -901,23 +900,45 @@ class Gemma3nAudioAttention(nn.Module):
         q_scale = self.head_dim**-0.5
         r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
         self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)

-        lower_causal_mask = torch.tril(
-            torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
-            diagonal=0,
-        ).T
-        upper_causal_mask = torch.tril(
-            torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
-            diagonal=self.max_past_horizon + self.max_future_horizon,
-        )
-        local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
-        local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
-        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)

         self.register_buffer(
             "softcap",
-            torch.tensor(self.attention_logits_soft_cap).float(),
+            torch.tensor(self.config.attention_logits_soft_cap).float(),
             persistent=False,
         )  # TODO do we even need a tensor for that?

+    def forward(
+        self, hidden_states: torch.Tensor, position_embeddings, attention_mask: torch.BoolTensor, **kwargs
+    ) -> torch.Tensor:
+        hidden_shape = hidden_states.shape[:-1]
+        query_states = self.q_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
+        key_states = self.k_proj(hidden_states).reshape(*hidden_shape, self.num_heads, self.head_dim).contiguous()
+        value_states = self.v_proj(hidden_states).reshape(*hidden_shape, self.head_dim).contiguous()
+
+        per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale).view(1, 1, 1, self.head_dim)
+        scaling = self.q_scale * per_dim_scale_sp
+
+        # shape is B, audio_length, num_heads, head_dim
+        num_blocks = (query_states.shape[1] + self.chunk_size - 1) // self.chunk_size
+        padding_length = num_blocks * self.chunk_size - query_states.shape[1]
+        query_states = torch.nn.functional.pad(query_states, (0, 0, 0, padding_length))
+        query_states = query_states.reshape(-1, self.num_blocks, self.chunk_size, self.num_heads, self.head_dim)
+
+        key_states = _frame_hidden_states(self.config, key_states)
+        value_states = _frame_hidden_states(self.config, value_states)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=scaling,
+            **kwargs,
+        )

     def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
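The `softcap` buffer registered above holds the attention logit cap (default 50.0). In Gemma-style attention this kind of cap is typically applied as a tanh squashing of the logits; a hedged sketch of that pattern, since the exact call site is not shown in this diff:

# Hedged sketch of tanh-based logit soft-capping as commonly used in
# Gemma-style attention; not necessarily the exact code path of this module.
import torch

def soft_cap(logits: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
    # Smoothly bounds attention logits to the open interval (-cap, cap).
    return cap * torch.tanh(logits / cap)

logits = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
print(soft_cap(logits))  # values bounded by +/-50, small logits nearly unchanged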
@@ -1089,15 +1110,14 @@ class Gemma3nAudioAttention(nn.Module):
         context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
         context_vectors = context_vectors.reshape(
             (
-                batch_size,
-                num_query_blocks * self.chunk_size,
+                hidden_shape[0],
+                query_states.shape[1] * self.chunk_size,
                 self.num_heads,
                 self.head_dim,
             )
         )
-        context_vectors = context_vectors[:, :q_time]
-
-        return context_vectors
+        context_vectors = context_vectors[:, : hidden_shape[1]]
+        return context_vectors, attn_weights


 class Gemma3nAudioCumulativeGroupNorm(nn.Module):
@@ -1388,7 +1408,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
         self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
         self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
         self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
-        self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+        self.post_layer_scale = torch.tensor(self.config.residual_weight)

     def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
         residual = audio_encodings
@@ -1412,7 +1432,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
         self.depthwise_conv1d = nn.Conv1d(
             in_channels=self.config.hidden_size,
             out_channels=self.config.hidden_size,
-            kernel_size=self.config.conf_conv_kernel_size,
+            kernel_size=self.config.conv_kernel_size,
             stride=1,
             padding=0,  # Manual causal padding
             groups=self.config.hidden_size,  # Depthwise
@@ -1422,7 +1442,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
         self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
         self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)

-        self.causal_padding = self.config.conf_conv_kernel_size - 1
+        self.causal_padding = self.config.conv_kernel_size - 1

     def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
         audio_encodings_residual = audio_encodings  # Save for residual connection
@@ -1484,9 +1504,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         self.config = config

         self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
-        self.conformer = nn.ModuleList(
-            [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
-        )
+        self.conformer = nn.ModuleList([Gemma3nAudioConformerBlock(config) for _ in range(config.num_hidden_layers)])

     def forward(
         self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
@@ -1535,10 +1553,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         for block in self.conformer:
             audio_encodings = block(audio_encodings, current_mask)  # Pass the processed mask

-        if self.config.conf_reduction_factor > 1:
-            audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+        if self.config.reduction_factor > 1:
+            audio_encodings = audio_encodings[:, :: self.config.reduction_factor]
             # Reduce the mask as well
-            current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+            current_mask = current_mask[:, :: self.config.reduction_factor]

         audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
         return audio_encodings, current_mask
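The stride-based time reduction at the end of the encoder forward can be checked on its own. A standalone sketch with assumed shapes (batch 2, 100 frames, hidden size 8), not outputs from a real Gemma3n run:

# Standalone sketch of the reduction_factor subsampling shown above;
# shapes are illustrative only.
import torch

reduction_factor = 4
audio_encodings = torch.randn(2, 100, 8)              # [batch, time, hidden]
current_mask = torch.zeros(2, 100, dtype=torch.bool)  # True marks padded frames

if reduction_factor > 1:
    audio_encodings = audio_encodings[:, ::reduction_factor]  # keep every 4th frame
    current_mask = current_mask[:, ::reduction_factor]        # reduce the mask the same way

audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
print(audio_encodings.shape, current_mask.shape)  # torch.Size([2, 25, 8]) torch.Size([2, 25])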