Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-25 23:38:59 +06:00)
Fix: Mamba2 norm_before_gate usage (#32686)

* mamba2 uses norm_before_gate=False
* small nit
* remove norm_before_gate flag and follow False path only

parent: 01c4fc455b
commit: c63a3d0f17
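For reference, here is a minimal sketch (assumed helper names, not the transformers implementation) of the two orderings the removed flag used to select between. Per the commit message, Mamba2 always follows the norm_before_gate=False path, i.e. the SiLU gate is applied before the RMS normalization:

    import torch
    import torch.nn.functional as F

    def rms_norm(x, weight, eps=1e-6):
        # Plain RMS normalization over the last dimension.
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x * weight

    def gated_norm(hidden_states, gate, weight, norm_before_gate):
        if norm_before_gate:
            # norm_before_gate=True: normalize first, then apply the SiLU gate.
            return rms_norm(hidden_states, weight) * F.silu(gate)
        # norm_before_gate=False (the Mamba2 path): gate first, then normalize.
        return rms_norm(hidden_states * F.silu(gate), weight)

    # Example (hypothetical shapes): batch=2, seq_len=8, hidden_size=64
    # out = gated_norm(torch.randn(2, 8, 64), torch.randn(2, 8, 64),
    #                  torch.ones(64), norm_before_gate=False)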
@@ -83,8 +83,6 @@ class Mamba2Config(PretrainedConfig):
             Whether or not to rescale `out_proj` weights when initializing.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the cache should be used.
-        norm_before_gate (`bool`, *optional*, defaults to `True`):
-            Option of cuda kernels -whether to normalize before the gate or not.
         rms_norm (`bool`, *optional*, defaults to `True`):
             Whether to use RMS norm or not.
         chunk_size (`int`, *optional*, defaults to 256):
@@ -137,7 +135,6 @@ class Mamba2Config(PretrainedConfig):
         time_step_limit=(0.0, float("inf")),
         rescale_prenorm_residual=False,
         use_cache=True,
-        norm_before_gate=True,
         rms_norm=True,
         chunk_size=256,
         tie_word_embeddings=False,
@@ -168,7 +165,6 @@ class Mamba2Config(PretrainedConfig):
         self.n_groups = n_groups
         self.num_heads = num_heads
         self.head_dim = head_dim
-        self.norm_before_gate = norm_before_gate
         self.rms_norm = rms_norm
         self.state_size = state_size
         self.chunk_size = chunk_size

@@ -208,7 +208,6 @@ class Mamba2Mixer(nn.Module):
         self.activation = config.hidden_act
         self.act = ACT2FN[config.hidden_act]

-        self.norm_before_gate = config.norm_before_gate
         self.layer_norm_epsilon = config.layer_norm_epsilon
         self.rms_norm = config.rms_norm

@@ -347,7 +346,7 @@ class Mamba2Mixer(nn.Module):
                 outproj_bias=self.out_proj.bias,
                 headdim=self.head_dim,
                 ngroups=self.n_groups,
-                norm_before_gate=self.norm_before_gate,
+                norm_before_gate=False,
                 return_final_states=True,
                 **dt_limit_kwargs,
             )
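Outside the fused CUDA kernel call above, the same convention has to hold in the gated normalization used by the non-kernel path. Below is a minimal, self-contained sketch of a gated RMS norm that always gates before normalizing; the class name and shapes are assumptions for illustration, not copied from the repository:

    import torch
    from torch import nn

    class GatedRMSNorm(nn.Module):
        # Gated RMS norm following the norm_before_gate=False convention:
        # the SiLU gate multiplies the hidden states before normalization.
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states, gate=None):
            if gate is not None:
                hidden_states = hidden_states * nn.functional.silu(gate)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states

    # Example usage with hidden states and gate of shape (batch, seq_len, hidden_size):
    # norm = GatedRMSNorm(hidden_size=64)
    # out = norm(torch.randn(2, 8, 64), gate=torch.randn(2, 8, 64))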