🚨 [Mistral and friends] Update MLP (#31057)

NielsRogge 2024-06-03 14:57:07 +02:00 committed by GitHub
parent d475f76745
commit 70c8713872
4 changed files with 8 additions and 12 deletions
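
All four diffs below make the same two edits to the models' shared gated-MLP block: the unused self.config attribute is dropped from __init__, and the forward argument is renamed from x to hidden_state. The rename breaks call sites that passed the input by keyword (mlp(x=...)), which is presumably why the title carries the 🚨 breaking-change marker; positional calls are unaffected. The block itself computes down_proj(act_fn(gate_proj(h)) * up_proj(h)), i.e. a SwiGLU-style MLP when hidden_act is "silu". Below is a minimal, runnable sketch of that pattern, assuming only PyTorch; GatedMLP, DummyConfig, and the inline ACT2FN stand-in are illustrative names, not the transformers source (transformers resolves the activation through transformers.activations.ACT2FN).

import torch
from torch import nn

# Tiny stand-in for transformers.activations.ACT2FN so the sketch is self-contained.
ACT2FN = {"silu": nn.SiLU(), "gelu": nn.GELU()}


class GatedMLP(nn.Module):
    """Gated MLP of the shape shared by Jamba/Mistral/Qwen2/StableLm (illustrative)."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Two up projections: one is passed through the activation and gates the other.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        # SwiGLU-style gating when hidden_act is "silu".
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class DummyConfig:  # hypothetical config with just the fields the MLP reads
    hidden_size = 32
    intermediate_size = 64
    hidden_act = "silu"


mlp = GatedMLP(DummyConfig())
h = torch.randn(2, 5, 32)       # (batch, seq_len, hidden_size)
out = mlp(h)                    # positional call: unaffected by the rename
out = mlp(hidden_state=h)       # keyword call: was mlp(x=h) before this commit
assert out.shape == h.shape     # the MLP preserves the hidden size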

src/transformers/models/jamba/modeling_jamba.py

@@ -1001,7 +1001,6 @@ class JambaMambaMixer(nn.Module):
 class JambaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -1009,8 +1008,8 @@ class JambaMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
 # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba

src/transformers/models/mistral/modeling_mistral.py

@@ -160,7 +160,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class MistralMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -168,8 +167,8 @@ class MistralMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv

src/transformers/models/qwen2/modeling_qwen2.py

@@ -173,7 +173,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 class Qwen2MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -181,8 +180,8 @@ class Qwen2MLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv

src/transformers/models/stablelm/modeling_stablelm.py

@@ -197,7 +197,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 class StableLmMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -205,8 +204,8 @@ class StableLmMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
 class StableLmLayerNormPerHead(nn.Module):