Mirror of https://github.com/huggingface/transformers.git
[Moonshine] compute head_dim_padding at init (#35984)
compute head_dim_padding at init
parent d7188ba600
commit e6f4a4ebbf
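The change moves the ceil-to-multiple arithmetic out of every `forward` call and into `__init__`, so each attention layer computes `head_dim_padding` exactly once. A minimal sketch of that arithmetic, pulled out into a standalone helper (the function name is mine, not part of the diff):

```python
from typing import Optional


def compute_head_dim_padding(head_dim: int, multiple: Optional[int]) -> int:
    """How many zero channels are needed to round head_dim up to the next multiple (0 if disabled)."""
    if multiple is None:
        return 0
    target_head_dim = multiple * ((head_dim + multiple - 1) // multiple)
    return target_head_dim - head_dim


# Illustrative values (not asserted to be Moonshine's): 36 needs 4 extra channels to reach 40.
assert compute_head_dim_padding(36, 8) == 4
assert compute_head_dim_padding(64, 8) == 0
assert compute_head_dim_padding(64, None) == 0
```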
src/transformers/models/moonshine/modeling_moonshine.py

@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 from typing import Callable, Optional, Tuple, Union
 
 import numpy as np
@@ -207,6 +206,14 @@ class MoonshineAttention(nn.Module):
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
 
+        # Pad head dimension to the next specified multiple.
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -276,21 +283,10 @@ class MoonshineAttention(nn.Module):
 
         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
 
-        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
-        head_dim_padding = 0
-        if self.config.pad_head_dim_to_multiple_of is not None:
-            head_dim = query_states.shape[-1]
-            target_multiple = self.config.pad_head_dim_to_multiple_of
-            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
-            head_dim_padding = target_head_dim - head_dim
-            if head_dim_padding > 0:
-                # Ensure scaling is correct even with padding.
-                if self.scaling is None:
-                    self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
-
-                query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
-                key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
-                value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -304,9 +300,8 @@ class MoonshineAttention(nn.Module):
             **kwargs,
        )
 
-        # Remove head size padding.
-        if head_dim_padding > 0:
-            attn_output = attn_output[:, :, :, :-head_dim_padding]
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
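Precomputing the padding (and dropping the in-forward `self.scaling` fallback) is safe because zero-padding Q, K and V along the head dimension leaves the attention output unchanged, as long as the softmax scale keeps using the unpadded head_dim. A minimal check of that property, assuming a PyTorch version whose `scaled_dot_product_attention` accepts an explicit `scale` argument; shapes are illustrative, not Moonshine's:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, num_heads, q_len, head_dim, pad = 1, 2, 5, 36, 4  # illustrative shapes only
q, k, v = (torch.randn(bsz, num_heads, q_len, head_dim) for _ in range(3))

scale = head_dim**-0.5  # scaling from the UNPADDED head dim, as in the removed fallback

# Reference: attention on the original, unpadded tensors.
ref = F.scaled_dot_product_attention(q, k, v, scale=scale)

# Padded path: zero-pad the head dimension, attend, then strip the padded channels.
q_p, k_p, v_p = (F.pad(t, (0, pad)) for t in (q, k, v))
out = F.scaled_dot_product_attention(q_p, k_p, v_p, scale=scale)[..., :-pad]

torch.testing.assert_close(out, ref)
```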
src/transformers/models/moonshine/modular_moonshine.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 from typing import Callable, Optional, Tuple, Union
 
 import torch
@@ -302,6 +301,15 @@ class MoonshineAttention(GlmAttention):
         config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
         super().__init__(config, layer_idx)
         self.is_causal = is_causal
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+
+        # Pad head dimension to the next specified multiple.
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0
 
     def forward(
         self,
@@ -372,21 +380,10 @@ class MoonshineAttention(GlmAttention):
 
         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
 
-        # Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
-        head_dim_padding = 0
-        if self.config.pad_head_dim_to_multiple_of is not None:
-            head_dim = query_states.shape[-1]
-            target_multiple = self.config.pad_head_dim_to_multiple_of
-            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
-            head_dim_padding = target_head_dim - head_dim
-            if head_dim_padding > 0:
-                # Ensure scaling is correct even with padding.
-                if self.scaling is None:
-                    self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
-
-                query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
-                key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
-                value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -400,9 +397,8 @@ class MoonshineAttention(GlmAttention):
             **kwargs,
         )
 
-        # Remove head size padding.
-        if head_dim_padding > 0:
-            attn_output = attn_output[:, :, :, :-head_dim_padding]
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
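One detail both the old and the new slice share: the `> 0` guard around the padding removal is load-bearing, because a `:-0` slice would drop the entire head dimension instead of keeping it. A quick illustration on a 4-D tensor whose last axis stands in for the (padded) head dimension:

```python
import torch

attn_output = torch.arange(24.0).reshape(1, 2, 3, 4)  # last axis plays the role of head_dim
padding = 0

print(attn_output[..., :-padding].shape)                          # torch.Size([1, 2, 3, 0]) -- everything dropped
print(attn_output[..., : attn_output.shape[-1] - padding].shape)  # torch.Size([1, 2, 3, 4]) -- unchanged
```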