Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
BLTMLP identical to MllamaTextMLP
This commit is contained in:
parent dc09e71765
commit 91be87ec9b
@@ -54,25 +54,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
+# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP
 class BLTMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        # Calculate intermediate_size based on actual hidden_size (not config.dim)
-        base_dim = 4 * self.hidden_size
-        intermediate_dim = int(2 * base_dim / 3)
-        self.intermediate_size = config.multiple_of * ((intermediate_dim + config.multiple_of - 1) // config.multiple_of)
-
+        self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
 
 # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
 def eager_attention_forward(
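With this hunk, BLTMLP no longer derives its own intermediate size: it reads config.hidden_size and config.intermediate_size and applies the usual SwiGLU-style gated MLP, matching MllamaTextMLP line for line. A minimal standalone sketch of the resulting module, using a toy stand-in config and a stand-in ACT2FN mapping rather than the real transformers BLTConfig, behaves like this:

import torch
import torch.nn as nn

ACT2FN = {"silu": nn.SiLU()}  # stand-in for transformers.activations.ACT2FN


class ToyConfig:  # stand-in for the real BLT config
    hidden_size = 512
    intermediate_size = 1536
    hidden_act = "silu"


class BLTMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating: act(gate_proj(x)) * up_proj(x), projected back to hidden_size
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


mlp = BLTMLP(ToyConfig())
print(mlp(torch.randn(2, 8, 512)).shape)  # torch.Size([2, 8, 512])

The one-line return is split into down_proj = ...; return down_proj purely to mirror MllamaTextMLP; the computation itself is unchanged.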
@@ -259,7 +255,14 @@ class BLTTransformerLayer(nn.Module):
 
         self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)
 
-        self.mlp = BLTMLP(config=config)
+        # Create a copy of config for MLP with pre-calculated dimensions
+        mlp_config = type(config)(**config.__dict__)
+        mlp_config.hidden_size = dim
+
+        # Calculate intermediate_size using the same logic as BLTMLP
+        mlp_config.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)
+
+        self.mlp = BLTMLP(config=mlp_config)
         self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps)
         self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps)
 
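The sizing logic removed from BLTMLP now lives in BLTTransformerLayer: the layer copies the config with type(config)(**config.__dict__), overrides hidden_size with dim, and pre-computes intermediate_size, so the shared config object is left untouched and BLTMLP itself stays identical to MllamaTextMLP. The new formula is equivalent to the old in-class one, since int(2 * (4 * dim) / 3) == int(8 * dim / 3). A small sketch of that arithmetic (the helper names are illustrative; only the formulas come from the diff):

# Round int(8 * dim / 3) up to the nearest multiple of `multiple_of`,
# as done in the new BLTTransformerLayer code.
def blt_intermediate_size(dim: int, multiple_of: int) -> int:
    return multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)


# The computation that used to live inside BLTMLP; it produces the same value.
def old_blt_intermediate_size(hidden_size: int, multiple_of: int) -> int:
    base_dim = 4 * hidden_size
    intermediate_dim = int(2 * base_dim / 3)
    return multiple_of * ((intermediate_dim + multiple_of - 1) // multiple_of)


for dim in (512, 768, 1024):
    assert blt_intermediate_size(dim, 256) == old_blt_intermediate_size(dim, 256)

print(blt_intermediate_size(512, 256))  # 1365 rounded up to 1536
print(blt_intermediate_size(768, 256))  # 2048, already a multiple of 256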