[Moonshine] compute head_dim_padding at init (#35984)

compute head_dim_padding at init
eustlb 2025-01-31 14:26:52 +01:00 committed by GitHub
parent d7188ba600
commit e6f4a4ebbf
2 changed files with 29 additions and 38 deletions
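
The padding amount depends only on config values (head_dim and pad_head_dim_to_multiple_of), so it can be computed once in __init__ instead of on every forward call; the in-forward scaling fallback and the now-unused math import go away with it. The rounding itself is a ceil-to-multiple: round head_dim up to the nearest multiple of pad_head_dim_to_multiple_of and pad by the difference. A minimal sketch of the arithmetic, with illustrative numbers that are not taken from any Moonshine config:

head_dim = 36          # illustrative value, not from this diff
target_multiple = 8    # illustrative pad_head_dim_to_multiple_of

# Ceil division: smallest multiple of target_multiple that is >= head_dim.
target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
head_dim_padding = target_head_dim - head_dim

print(target_head_dim, head_dim_padding)  # 40 4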


@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Callable, Optional, Tuple, Union
 import numpy as np
@@ -207,6 +206,14 @@ class MoonshineAttention(nn.Module):
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+        # Pad head dimension to the next specified multiple.
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0

     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -276,21 +283,10 @@ class MoonshineAttention(nn.Module):
         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False

-        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
-        head_dim_padding = 0
-        if self.config.pad_head_dim_to_multiple_of is not None:
-            head_dim = query_states.shape[-1]
-            target_multiple = self.config.pad_head_dim_to_multiple_of
-            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
-            head_dim_padding = target_head_dim - head_dim
-
-        if head_dim_padding > 0:
-            # Ensure scaling is correct even with padding.
-            if self.scaling is None:
-                self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
-            query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
-            key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
-            value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))

         attn_output, attn_weights = attention_interface(
             self,
@@ -304,9 +300,8 @@ class MoonshineAttention(nn.Module):
             **kwargs,
         )

-        # Remove head size padding.
-        if head_dim_padding > 0:
-            attn_output = attn_output[:, :, :, :-head_dim_padding]
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]

         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
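
The modular definition below gets the identical treatment. As a standalone sanity check of the pad/unpad mechanics (not Moonshine code): torch.nn.functional.pad with a (0, p) argument zero-pads only the last (head size) dimension, and slicing that dimension back off recovers the original tensor.

import torch
import torch.nn.functional as F

# Illustrative shape (batch, num_heads, seq_len, head_dim); values are arbitrary.
q = torch.randn(2, 8, 5, 36)
head_dim_padding = 4

q_padded = F.pad(q, (0, head_dim_padding))      # zero-pad the last dim on the right
assert q_padded.shape[-1] == 40

q_unpadded = q_padded[..., :-head_dim_padding]  # strip the padding again
assert torch.equal(q_unpadded, q)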


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Callable, Optional, Tuple, Union
 import torch
@@ -302,6 +301,15 @@ class MoonshineAttention(GlmAttention):
         config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
         super().__init__(config, layer_idx)
         self.is_causal = is_causal
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+
+        # Pad head dimension to the next specified multiple.
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0

     def forward(
         self,
@@ -372,21 +380,10 @@ class MoonshineAttention(GlmAttention):
         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False

-        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
-        head_dim_padding = 0
-        if self.config.pad_head_dim_to_multiple_of is not None:
-            head_dim = query_states.shape[-1]
-            target_multiple = self.config.pad_head_dim_to_multiple_of
-            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
-            head_dim_padding = target_head_dim - head_dim
-
-        if head_dim_padding > 0:
-            # Ensure scaling is correct even with padding.
-            if self.scaling is None:
-                self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
-            query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
-            key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
-            value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))

         attn_output, attn_weights = attention_interface(
             self,
@@ -400,9 +397,8 @@ class MoonshineAttention(GlmAttention):
             **kwargs,
         )

-        # Remove head size padding.
-        if head_dim_padding > 0:
-            attn_output = attn_output[:, :, :, :-head_dim_padding]
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]

         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
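
A small cosmetic change rides along with the refactor: the output slice switches from an explicit four-axis index to an ellipsis index. For a 4-D attention output the two spellings select the same elements; a quick standalone check with an illustrative shape:

import torch

attn_output = torch.randn(2, 5, 8, 40)  # illustrative 4-D attention output
p = 4
assert torch.equal(attn_output[..., :-p], attn_output[:, :, :, :-p])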