fix other models as well!

Arthur 2025-06-30 14:55:01 +02:00
parent 8c66f4d0bb
commit 3caf7d76a0
20 changed files with 43 additions and 54 deletions


@@ -285,9 +285,8 @@ class ArceeDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -296,7 +295,7 @@ class ArceeDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -645,6 +644,7 @@ class ArceeForQuestionAnswering(ArceePreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state
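
Every decoder-layer hunk in this commit is the same two-line change: the call into self.self_attn no longer indexes [0] on the returned tuple but unpacks it into (hidden_states, _), and the redundant "# Self Attention" comment is dropped. Below is a minimal sketch of why the two forms are interchangeable when the attention module returns a pair; ToySelfAttention is an illustrative stand-in, not the actual transformers attention class.

import torch
from torch import nn


class ToySelfAttention(nn.Module):
    """Stand-in for an attention module that is assumed to return
    a pair (attn_output, attn_weights), as in the diff above."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        attn_output = self.proj(hidden_states)
        attn_weights = None  # often unused by the calling layer
        return attn_output, attn_weights


hidden_states = torch.randn(1, 4, 8)
self_attn = ToySelfAttention(8)

# Old pattern removed by the diff: call, then index the first element.
out_indexed = self_attn(hidden_states)[0]

# New pattern added by the diff: unpack the pair, discard the weights.
out_unpacked, _ = self_attn(hidden_states)

assert torch.equal(out_indexed, out_unpacked)

One practical difference: tuple unpacking raises immediately if the module ever starts returning a different number of values, whereas [0] would keep passing silently.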


@@ -601,9 +601,8 @@ class AriaTextDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -612,7 +611,7 @@ class AriaTextDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -246,9 +246,8 @@ class BitNetDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -257,7 +256,7 @@ class BitNetDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -371,9 +371,8 @@ class CsmDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -382,7 +381,7 @@ class CsmDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -466,9 +466,8 @@ class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -477,7 +476,7 @@ class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -517,9 +517,8 @@ class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -528,7 +527,7 @@ class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -929,6 +928,7 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state


@@ -388,9 +388,8 @@ class Dots1DecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -399,7 +398,7 @@ class Dots1DecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -283,9 +283,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -294,7 +293,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -299,9 +299,8 @@ class GlmDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -310,7 +309,7 @@ class GlmDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -332,9 +332,8 @@ class GPTNeoXDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -343,7 +342,7 @@ class GPTNeoXDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -284,9 +284,8 @@ class HeliumDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -295,7 +294,7 @@ class HeliumDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -226,9 +226,8 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -237,7 +236,7 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -359,9 +359,8 @@ class MoonshineEncoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -370,7 +369,7 @@ class MoonshineEncoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -1110,6 +1110,7 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state
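
Each *ForQuestionAnswering hunk adds **kwargs to the call into the backbone, so keyword arguments accepted by the question-answering head's forward are forwarded to the underlying model instead of being dropped. A hedged sketch of the forwarding pattern with toy classes (ToyBackbone and ToyQuestionAnsweringHead are illustrations, not the transformers API):

import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Stand-in base model that accepts an optional keyword argument."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.embed = nn.Linear(4, hidden_size)

    def forward(self, inputs: torch.Tensor, *, output_scale: float = 1.0) -> torch.Tensor:
        return self.embed(inputs) * output_scale


class ToyQuestionAnsweringHead(nn.Module):
    """QA-style wrapper: two logits per position for span start/end."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.backbone = ToyBackbone(hidden_size)
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        # Without **kwargs here, an option passed by the caller
        # (output_scale in this toy) would never reach the backbone.
        sequence_output = self.backbone(inputs, **kwargs)
        return self.qa_outputs(sequence_output)


model = ToyQuestionAnsweringHead(hidden_size=8)
logits = model(torch.randn(2, 3, 4), output_scale=0.5)
print(logits.shape)  # torch.Size([2, 3, 2])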


@@ -229,9 +229,8 @@ class OlmoDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -240,7 +239,7 @@ class OlmoDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected


@@ -230,9 +230,8 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -241,7 +240,7 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -758,6 +757,7 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state


@@ -256,9 +256,8 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -267,7 +266,7 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -784,6 +783,7 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state


@@ -1000,6 +1000,7 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state


@@ -259,9 +259,8 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -270,7 +269,7 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -787,6 +786,7 @@ class SmolLM3ForQuestionAnswering(SmolLM3PreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state


@@ -229,9 +229,8 @@ class Starcoder2DecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -240,7 +239,7 @@ class Starcoder2DecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected
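
Given the commit title, one quick sanity check is to scan the repository for any modeling file that still uses the old indexing form. A rough sketch; the src/transformers/models path and the regular expression are assumptions for illustration, not part of the commit:

import re
from pathlib import Path

# Hypothetical checkout location; adjust to your clone of the repository.
MODELS_DIR = Path("src/transformers/models")

# Roughly matches a self.self_attn(...) call whose result is indexed with [0].
OLD_PATTERN = re.compile(r"self\.self_attn\([^)]*\)\[0\]")

for path in sorted(MODELS_DIR.rglob("modeling_*.py")):
    text = path.read_text(encoding="utf-8")
    if OLD_PATTERN.search(text):
        print(f"old indexing pattern still present: {path}")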