mirror of https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00

fix small lm3

parent 6a132a0799
commit 9fa5f266a1
@@ -182,6 +182,7 @@ class SmolLM3Config(PretrainedConfig):
         layer_types=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         super().__init__(
@@ -192,6 +193,7 @@ class SmolLM3Config(PretrainedConfig):
         )
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
+        self.mlp_bias = mlp_bias
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
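For reference, a minimal sketch (not the SmolLM3 source; class names and defaults here are illustrative) of how a config flag like the newly stored mlp_bias typically reaches the bias arguments of an MLP's Linear layers:

# A minimal sketch, not the library code: it shows how a stored config flag
# such as mlp_bias controls the bias of an MLP's Linear layers.
import torch
import torch.nn as nn


class SketchConfig:
    def __init__(self, hidden_size=64, intermediate_size=256, mlp_bias=False):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Without an assignment like this (the line added in this commit),
        # modules reading config.mlp_bias would fail with AttributeError.
        self.mlp_bias = mlp_bias


class SketchMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Llama-style gate/up/down projections; the config flag decides the bias.
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


mlp = SketchMLP(SketchConfig(mlp_bias=True))
print(mlp(torch.randn(1, 64)).shape)  # torch.Size([1, 64])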
@@ -246,6 +246,7 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer):
         self.mlp = SmolLM3MLP(config)
         self.input_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
@@ -26,6 +26,7 @@ from ...processing_utils import Unpack
 from ...utils import logging
 from ..llama.modeling_llama import (
     LlamaAttention,
+    LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaForQuestionAnswering,
     LlamaForSequenceClassification,
@@ -199,6 +200,7 @@ class SmolLM3Config(PretrainedConfig):
         layer_types=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         super().__init__(
@@ -209,6 +211,7 @@ class SmolLM3Config(PretrainedConfig):
         )
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
+        self.mlp_bias = mlp_bias
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
@@ -315,6 +318,12 @@ class SmolLM3Attention(LlamaAttention):
         return attn_output, attn_weights


+class SmolLM3DecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: SmolLM3Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.attention_type = config.layer_types[layer_idx]
+
+
 class SmolLM3PreTrainedModel(LlamaPreTrainedModel):
     pass
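As a hedged illustration (the names below are hypothetical, not the library API), this is the kind of per-layer lookup that storing attention_type on each decoder layer enables:

# A minimal sketch, not the library code: it shows why recording attention_type
# per layer is useful, e.g. for choosing a full vs. sliding attention mask.
# The "full_attention" / "sliding_attention" values follow a common convention
# but are illustrative here.
class SketchConfig:
    def __init__(self, num_hidden_layers=4, layer_types=None):
        self.num_hidden_layers = num_hidden_layers
        self.layer_types = layer_types or ["full_attention"] * num_hidden_layers


class SketchDecoderLayer:
    def __init__(self, config, layer_idx):
        # The attribute added in this commit: each layer remembers its own type,
        # so mask-building code can branch on layer.attention_type.
        self.attention_type = config.layer_types[layer_idx]


cfg = SketchConfig(layer_types=["full_attention", "sliding_attention"] * 2)
layers = [SketchDecoderLayer(cfg, i) for i in range(cfg.num_hidden_layers)]
assert [layer.attention_type for layer in layers] == cfg.layer_types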