Fix: Mamba2 norm_before_gate usage (#32686)

* mamba2 uses norm_before_gate=False

* small nit

* remove norm_before_gate flag and follow False path only
Anton Vlasjuk 2024-08-20 19:47:34 +02:00 committed by GitHub
parent 01c4fc455b
commit c63a3d0f17
2 changed files with 1 addition and 6 deletions
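For context: the flag controls whether the RMSNorm is applied before or after the SiLU gate is multiplied in, and Mamba2 gates first and then normalizes, i.e. the `False` path this commit hard-codes. A minimal plain-PyTorch sketch (illustrative names `rmsnorm`/`gated_norm`, plain RMSNorm assumed) showing that the two orderings genuinely differ:

```python
import torch
import torch.nn.functional as F

def rmsnorm(x, weight, eps=1e-6):
    # Plain RMSNorm over the last dimension.
    return weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def gated_norm(x, gate, weight, norm_before_gate):
    if norm_before_gate:
        # norm_before_gate=True: normalize first, then multiply by the SiLU gate.
        return rmsnorm(x, weight) * F.silu(gate)
    # norm_before_gate=False (the Mamba2 path): gate first, then normalize.
    return rmsnorm(x * F.silu(gate), weight)

x, z, w = torch.randn(2, 8, 64), torch.randn(2, 8, 64), torch.ones(64)
# The orderings are not equivalent, which is why hard-coding False matters.
print(torch.allclose(gated_norm(x, z, w, True), gated_norm(x, z, w, False)))  # -> False
```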

src/transformers/models/mamba2/configuration_mamba2.py

@@ -83,8 +83,6 @@ class Mamba2Config(PretrainedConfig):
         Whether or not to rescale `out_proj` weights when initializing.
     use_cache (`bool`, *optional*, defaults to `True`):
         Whether or not the cache should be used.
-    norm_before_gate (`bool`, *optional*, defaults to `True`):
-        Option of cuda kernels -whether to normalize before the gate or not.
     rms_norm (`bool`, *optional*, defaults to `True`):
         Whether to use RMS norm or not.
     chunk_size (`int`, *optional*, defaults to 256):
@@ -137,7 +135,6 @@ class Mamba2Config(PretrainedConfig):
         time_step_limit=(0.0, float("inf")),
         rescale_prenorm_residual=False,
         use_cache=True,
-        norm_before_gate=True,
         rms_norm=True,
         chunk_size=256,
         tie_word_embeddings=False,
@@ -168,7 +165,6 @@ class Mamba2Config(PretrainedConfig):
         self.n_groups = n_groups
         self.num_heads = num_heads
         self.head_dim = head_dim
-        self.norm_before_gate = norm_before_gate
         self.rms_norm = rms_norm
         self.state_size = state_size
         self.chunk_size = chunk_size

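Since the key is gone from `Mamba2Config`, older serialized configs that still contain `norm_before_gate` should keep loading without error: assuming `PretrainedConfig`'s usual handling of unknown kwargs, the key is stored as a plain attribute that nothing in the model reads anymore. A quick check, hedged on that assumption:

```python
from transformers import Mamba2Config

# `norm_before_gate` is no longer a named parameter; if PretrainedConfig keeps
# unrecognized kwargs as plain attributes (its usual behavior), this still
# constructs fine and the value is simply never read by the model.
config = Mamba2Config(norm_before_gate=True)
```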
src/transformers/models/mamba2/modeling_mamba2.py

@@ -208,7 +208,6 @@ class Mamba2Mixer(nn.Module):
         self.activation = config.hidden_act
         self.act = ACT2FN[config.hidden_act]
-        self.norm_before_gate = config.norm_before_gate
         self.layer_norm_epsilon = config.layer_norm_epsilon
         self.rms_norm = config.rms_norm
@@ -347,7 +346,7 @@ class Mamba2Mixer(nn.Module):
                 outproj_bias=self.out_proj.bias,
                 headdim=self.head_dim,
                 ngroups=self.n_groups,
-                norm_before_gate=self.norm_before_gate,
+                norm_before_gate=False,
                 return_final_states=True,
                 **dt_limit_kwargs,
             )
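With the fused kernel now always called with `norm_before_gate=False`, the surviving code path corresponds to gating before normalizing. A standalone sketch of such a gated RMSNorm (illustrative, not the model's exact class):

```python
import torch
from torch import nn

class RMSNormGated(nn.Module):
    """Gated RMSNorm hard-coding the norm_before_gate=False ordering."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states, gate=None):
        dtype = hidden_states.dtype
        hidden_states = hidden_states.float()
        if gate is not None:
            # Gate first (SiLU), then normalize: the False path kept by this commit.
            hidden_states = hidden_states * nn.functional.silu(gate.float())
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(dtype)
```

This keeps the fused and eager paths consistent with the reference `mamba_ssm` implementation, which, to my knowledge, also defaults to `norm_before_gate=False` for Mamba2.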