diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 96285262514..d82f715fbd5 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 from typing import Callable, Optional, Tuple, Union
 
 import numpy as np
@@ -207,6 +206,14 @@ class MoonshineAttention(nn.Module):
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
 
+        # Pad head dimension to the next specified multiple.
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -276,21 +283,10 @@ class MoonshineAttention(nn.Module):
 
         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
 
-        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
-        head_dim_padding = 0
-        if self.config.pad_head_dim_to_multiple_of is not None:
-            head_dim = query_states.shape[-1]
-            target_multiple = self.config.pad_head_dim_to_multiple_of
-            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
-            head_dim_padding = target_head_dim - head_dim
-        if head_dim_padding > 0:
-            # Ensure scaling is correct even with padding.
-            if self.scaling is None:
-                self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
-
-            query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
-            key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
-            value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -304,9 +300,8 @@ class MoonshineAttention(nn.Module):
             **kwargs,
         )
 
-        # Remove head size padding.
-        if head_dim_padding > 0:
-            attn_output = attn_output[:, :, :, :-head_dim_padding]
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py
index a78b153725d..24fa4f0a1ef 100644
--- a/src/transformers/models/moonshine/modular_moonshine.py
+++ b/src/transformers/models/moonshine/modular_moonshine.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 from typing import Callable, Optional, Tuple, Union
 
 import torch
@@ -302,6 +301,15 @@ class MoonshineAttention(GlmAttention):
         config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
         super().__init__(config, layer_idx)
         self.is_causal = is_causal
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+
+        # Pad head dimension to the next specified multiple.
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0
 
     def forward(
         self,
@@ -372,21 +380,10 @@ class MoonshineAttention(GlmAttention):
 
         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
 
-        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
-        head_dim_padding = 0
-        if self.config.pad_head_dim_to_multiple_of is not None:
-            head_dim = query_states.shape[-1]
-            target_multiple = self.config.pad_head_dim_to_multiple_of
-            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
-            head_dim_padding = target_head_dim - head_dim
-        if head_dim_padding > 0:
-            # Ensure scaling is correct even with padding.
-            if self.scaling is None:
-                self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
-
-            query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
-            key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
-            value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -400,9 +397,8 @@ class MoonshineAttention(GlmAttention):
             **kwargs,
        )
 
-        # Remove head size padding.
-        if head_dim_padding > 0:
-            attn_output = attn_output[:, :, :, :-head_dim_padding]
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
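For reference, below is a minimal, self-contained sketch of the padding scheme this patch precomputes in `__init__`: round the head dimension up to the next multiple of `pad_head_dim_to_multiple_of`, zero-pad q/k/v on the last axis before the attention call, and slice the padding off the output. The `compute_head_dim_padding` helper, the concrete sizes (`head_dim=36`, multiple of 8), and the direct `scaled_dot_product_attention` call are illustrative assumptions, not code from the PR.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def compute_head_dim_padding(head_dim: int, pad_head_dim_to_multiple_of: Optional[int]) -> int:
    """Round head_dim up to the next multiple and return how many columns of padding that needs."""
    if pad_head_dim_to_multiple_of is None:
        return 0
    multiple = pad_head_dim_to_multiple_of
    target_head_dim = multiple * ((head_dim + multiple - 1) // multiple)
    return target_head_dim - head_dim


# Example numbers (hypothetical): head_dim=36 rounded up to a multiple of 8 -> 40, i.e. 4 extra columns.
head_dim = 36
head_dim_padding = compute_head_dim_padding(head_dim, pad_head_dim_to_multiple_of=8)
assert head_dim_padding == 4

bsz, num_heads, q_len = 1, 8, 10
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Zero-pad q/k/v on the last (head_dim) axis before attention, as the new forward() does.
if head_dim_padding > 0:
    query_states = F.pad(query_states, (0, head_dim_padding))
    key_states = F.pad(key_states, (0, head_dim_padding))
    value_states = F.pad(value_states, (0, head_dim_padding))

# Plain SDPA stands in for the model's attention_interface; keep the softmax scale based on the
# *unpadded* head_dim so the padded result matches the unpadded one (PyTorch >= 2.1 for `scale`).
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, scale=head_dim**-0.5)

# Strip the padded columns from the output, mirroring `attn_output[..., : -self.head_dim_padding]`.
if head_dim_padding > 0:
    attn_output = attn_output[..., :-head_dim_padding]

assert attn_output.shape == (bsz, num_heads, q_len, head_dim)
```

Because the padded query/key columns are zeros they contribute nothing to the attention scores, and the padded value columns only produce zero output columns that are sliced off afterwards, so precomputing `head_dim_padding` once in `__init__` (rather than re-deriving it and patching `self.scaling` on every forward pass) should leave the numerics unchanged, provided the scale set in `__init__` is already based on the unpadded head dimension.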