fix small lm3

Arthur 2025-07-01 15:32:17 +02:00
parent 6a132a0799
commit 9fa5f266a1
3 changed files with 12 additions and 0 deletions

View File

@@ -182,6 +182,7 @@ class SmolLM3Config(PretrainedConfig):
         layer_types=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         super().__init__(
@@ -192,6 +193,7 @@ class SmolLM3Config(PretrainedConfig):
         )
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
+        self.mlp_bias = mlp_bias
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
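
Note: `mlp_bias` follows the usual Llama-family convention of toggling bias terms on the feed-forward projections. A minimal sketch of that pattern, for context only (this is not the actual SmolLM3MLP code, and `GatedMLPSketch` is a made-up name):

from torch import nn

class GatedMLPSketch(nn.Module):
    # Illustrative Llama-style gated MLP; `config` only needs hidden_size,
    # intermediate_size and the new mlp_bias flag.
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))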

View File

@@ -246,6 +246,7 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer):
         self.mlp = SmolLM3MLP(config)
         self.input_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
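
For context, a per-layer `attention_type` read from `config.layer_types` is typically used to pick which attention mask each decoder layer receives. A hedged sketch of that pattern (the mapping name and type strings are assumptions for illustration, not the exact SmolLM3 forward code):

def run_decoder_layers(layers, hidden_states, causal_mask_mapping, **kwargs):
    # causal_mask_mapping maps e.g. "full_attention" / "sliding_attention"
    # to the corresponding precomputed mask tensor.
    for layer in layers:
        hidden_states = layer(
            hidden_states,
            attention_mask=causal_mask_mapping[layer.attention_type],
            **kwargs,
        )
    return hidden_states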

View File

@@ -26,6 +26,7 @@ from ...processing_utils import Unpack
 from ...utils import logging
 from ..llama.modeling_llama import (
     LlamaAttention,
+    LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaForQuestionAnswering,
     LlamaForSequenceClassification,
@@ -199,6 +200,7 @@ class SmolLM3Config(PretrainedConfig):
         layer_types=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         super().__init__(
@@ -209,6 +211,7 @@
         )
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
+        self.mlp_bias = mlp_bias
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
@@ -315,6 +318,12 @@ class SmolLM3Attention(LlamaAttention):
         return attn_output, attn_weights


+class SmolLM3DecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: SmolLM3Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.attention_type = config.layer_types[layer_idx]
+
+
 class SmolLM3PreTrainedModel(LlamaPreTrainedModel):
     pass
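
A quick sanity check for the new options, assuming a transformers build that already includes this commit (the `"full_attention"` type string follows the convention used elsewhere in the library):

from transformers import SmolLM3Config

cfg = SmolLM3Config(
    num_hidden_layers=2,
    layer_types=["full_attention", "full_attention"],  # explicit per-layer attention types
    mlp_bias=True,  # new flag introduced by this commit
)
print(cfg.mlp_bias)     # True
print(cfg.layer_types)  # ['full_attention', 'full_attention']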