fix small lm3

Arthur 2025-07-01 15:32:17 +02:00
parent 6a132a0799
commit 9fa5f266a1
3 changed files with 12 additions and 0 deletions

View File

@@ -182,6 +182,7 @@ class SmolLM3Config(PretrainedConfig):
         layer_types=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         super().__init__(
@@ -192,6 +193,7 @@ class SmolLM3Config(PretrainedConfig):
         )
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
+        self.mlp_bias = mlp_bias
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
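
The new mlp_bias flag has to live on the config so the MLP block can read it when building its linear projections. Below is a minimal, hypothetical sketch of how a gated MLP typically consumes such a flag, assuming SmolLM3MLP follows the LlamaMLP pattern; the real SmolLM3MLP is not shown in this commit and may differ.

from types import SimpleNamespace
from torch import nn

class SketchGatedMLP(nn.Module):
    # Hypothetical stand-in for SmolLM3MLP: all three projections share the
    # same bias switch, config.mlp_bias (mirrors the LlamaMLP layout).
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

# Quick check with a throwaway config object (not a real SmolLM3Config).
cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, mlp_bias=False)
print(SketchGatedMLP(cfg).gate_proj.bias)  # None, because mlp_bias=False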

View File

@@ -246,6 +246,7 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer):
         self.mlp = SmolLM3MLP(config)
         self.input_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
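
Storing attention_type on each decoder layer lets the model later look up the right attention mask per layer from config.layer_types. A toy sketch of that lookup follows; the layer-type values and the mask mapping are assumptions for illustration, not taken from this commit, and the real mask construction happens in the model forward, which this diff does not touch.

# Toy illustration only; the names and values below are hypothetical.
layer_types = ["full_attention", "full_attention", "sliding_attention"]

class SketchDecoderLayer:
    def __init__(self, layer_types, layer_idx):
        # mirrors: self.attention_type = config.layer_types[layer_idx]
        self.attention_type = layer_types[layer_idx]

mask_for = {"full_attention": "causal mask", "sliding_attention": "sliding-window mask"}
for idx in range(len(layer_types)):
    layer = SketchDecoderLayer(layer_types, idx)
    print(idx, layer.attention_type, "->", mask_for[layer.attention_type])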

View File

@@ -26,6 +26,7 @@ from ...processing_utils import Unpack
 from ...utils import logging
 from ..llama.modeling_llama import (
     LlamaAttention,
+    LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaForQuestionAnswering,
     LlamaForSequenceClassification,
@@ -199,6 +200,7 @@ class SmolLM3Config(PretrainedConfig):
         layer_types=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         super().__init__(
@@ -209,6 +211,7 @@ class SmolLM3Config(PretrainedConfig):
         )
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
+        self.mlp_bias = mlp_bias
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
@@ -315,6 +318,12 @@ class SmolLM3Attention(LlamaAttention):
         return attn_output, attn_weights


+class SmolLM3DecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: SmolLM3Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.attention_type = config.layer_types[layer_idx]
+
+
 class SmolLM3PreTrainedModel(LlamaPreTrainedModel):
     pass
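
SmolLM3 uses the modular-transformers setup: the modular file above is the hand-edited source, and the flat configuration/modeling files (the first two diffs) are regenerated from it with the repo's modular converter utility, which is why the same additions appear twice in this commit. As a quick sanity check of the public-facing change, a hedged usage sketch (assumes a transformers version that already ships SmolLM3; only mlp_bias is passed explicitly):

from transformers import SmolLM3Config

config = SmolLM3Config(mlp_bias=True)
print(config.mlp_bias)        # True: the attribute added by this commit
print(config.attention_bias)  # pre-existing flag, still defaults to False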