Arthur 2025-07-01 14:19:38 +02:00
parent 7a0512a1f5
commit a13a98c6da
2 changed files with 4 additions and 23 deletions

View File

@@ -84,18 +84,15 @@ class Glm4DecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -110,18 +107,12 @@ class Glm4DecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_self_attn_layernorm(hidden_states)
         hidden_states = residual + hidden_states
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_mlp_layernorm(hidden_states)
         hidden_states = residual + hidden_states
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-        return outputs
+        return hidden_states
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
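For orientation, here is a minimal, self-contained sketch of the control flow these hunks converge to: the decoder layer now returns the hidden-states tensor directly instead of a tuple gated by output_attentions. The sub-modules below are hypothetical stand-ins (plain nn.LayerNorm / nn.Linear), not the real Glm4Attention or Glm4MLP implementations, and the signature is trimmed to what the sketch actually uses.

from typing import Optional

import torch
from torch import nn


class SketchGlm4DecoderLayer(nn.Module):
    """Hypothetical stand-in mirroring the post-commit control flow of Glm4DecoderLayer."""

    def __init__(self, hidden_size: int = 16):
        super().__init__()
        # Stand-ins for the real sub-modules, which are built from the model config.
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_self_attn_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.post_mlp_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # placeholder for Glm4Attention
        self.mlp = nn.Linear(hidden_size, hidden_size)        # placeholder for Glm4MLP

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Attention block: pre-norm, attention, post-norm, residual add.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states)  # the real call also takes mask/cache/RoPE kwargs
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Fully connected block: pre-norm, MLP, post-norm, residual add.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # After this commit the layer returns the tensor directly; the old
        # (hidden_states, self_attn_weights) tuple and the output_attentions gate are gone.
        return hidden_states


if __name__ == "__main__":
    layer = SketchGlm4DecoderLayer()
    out = layer(torch.randn(1, 4, 16))
    print(out.shape)  # torch.Size([1, 4, 16])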

View File

@@ -15,7 +15,6 @@
 # limitations under the License.
 from typing import Optional, Union
-import torch.utils.checkpoint
 from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -56,18 +55,15 @@ class Glm4DecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -82,18 +78,12 @@ class Glm4DecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_self_attn_layernorm(hidden_states)
         hidden_states = residual + hidden_states
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_mlp_layernorm(hidden_states)
         hidden_states = residual + hidden_states
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-        return outputs
+        return hidden_states
 class Glm4Attention(GlmAttention):
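A side note on the **kwargs: Unpack[FlashAttentionKwargs] annotation that both signatures keep: it types the forwarded keyword arguments with a TypedDict. The sketch below shows the pattern with an illustrative TypedDict; the field names here are assumptions for demonstration, not the actual contents of transformers' FlashAttentionKwargs.

from typing import Optional

import torch
from typing_extensions import TypedDict, Unpack  # Unpack is also in typing on Python >= 3.12


class AttentionKwargs(TypedDict, total=False):
    # Illustrative fields only; the real FlashAttentionKwargs defines its own.
    cu_seq_lens_q: Optional[torch.Tensor]
    max_length_q: Optional[int]


def forward(hidden_states: torch.Tensor, **kwargs: Unpack[AttentionKwargs]) -> torch.Tensor:
    # At runtime kwargs is an ordinary dict; the annotation only tells type checkers
    # which keyword arguments may be forwarded down to the attention call.
    return hidden_states


forward(torch.zeros(1, 2, 4), max_length_q=2)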