Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
BLTMLP identical to MllamaTextMLP
commit 91be87ec9b (parent dc09e71765)
@@ -54,25 +54,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


+# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP
 class BLTMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-
-        # Calculate intermediate_size based on actual hidden_size (not config.dim)
-        base_dim = 4 * self.hidden_size
-        intermediate_dim = int(2 * base_dim / 3)
-        self.intermediate_size = config.multiple_of * ((intermediate_dim + config.multiple_of - 1) // config.multiple_of)
-
+        self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj


 # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
 def eager_attention_forward(
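After this hunk, BLTMLP follows the MllamaTextMLP pattern: a gated (SwiGLU-style) MLP whose width is read directly from config.intermediate_size instead of being derived inside the module. Below is a minimal, self-contained sketch of that structure; the ToyConfig class and the inlined ACT2FN entry are stand-ins for illustration, not the real BLT configuration or the full transformers activation registry.

    import torch
    import torch.nn as nn

    # Stand-in for transformers' ACT2FN lookup; only the one entry we need here.
    ACT2FN = {"silu": nn.SiLU()}


    class ToyConfig:
        """Hypothetical minimal config exposing only the fields the MLP reads."""

        def __init__(self, hidden_size, intermediate_size, hidden_act="silu"):
            self.hidden_size = hidden_size
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act


    class GatedMLP(nn.Module):
        """Same layout as the post-change BLTMLP / MllamaTextMLP."""

        def __init__(self, config):
            super().__init__()
            self.hidden_size = config.hidden_size
            self.intermediate_size = config.intermediate_size
            self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
            self.act_fn = ACT2FN[config.hidden_act]

        def forward(self, x):
            # Gated MLP: act(gate(x)) multiplied elementwise with up(x), then projected back to hidden_size.
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


    mlp = GatedMLP(ToyConfig(hidden_size=512, intermediate_size=1408))
    print(mlp(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 512])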
@@ -259,7 +255,14 @@ class BLTTransformerLayer(nn.Module):

         self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)

-        self.mlp = BLTMLP(config=config)
+        # Create a copy of config for MLP with pre-calculated dimensions
+        mlp_config = type(config)(**config.__dict__)
+        mlp_config.hidden_size = dim
+
+        # Calculate intermediate_size using the same logic as BLTMLP
+        mlp_config.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)
+
+        self.mlp = BLTMLP(config=mlp_config)
         self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps)
         self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps)

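The second hunk moves the width calculation that the first hunk deleted from BLTMLP into BLTTransformerLayer: it rebuilds a config of the same class via type(config)(**config.__dict__) (which assumes every stored attribute is also accepted by the constructor), pins hidden_size to the layer's dim, and precomputes intermediate_size by rounding int(8 * dim / 3) up to the nearest multiple of multiple_of. For the same dim and multiple_of, this matches the expression removed from BLTMLP, since int(2 * (4 * dim) / 3) == int(8 * dim / 3). A small sketch of the rounding arithmetic, with dim and multiple_of values that are assumed for illustration rather than taken from the BLT defaults:

    def rounded_intermediate_size(dim: int, multiple_of: int) -> int:
        # Round 8*dim/3 up to the nearest multiple of `multiple_of`,
        # mirroring the expression now used in BLTTransformerLayer.
        intermediate_dim = int(8 * dim / 3)
        return multiple_of * ((intermediate_dim + multiple_of - 1) // multiple_of)


    # Illustrative values only.
    print(rounded_intermediate_size(768, 256))   # 8*768/3 = 2048.0, already a multiple of 256 -> 2048
    print(rounded_intermediate_size(1024, 256))  # int(8192/3) = 2730 -> rounded up to 2816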