BLTMLP identical to MllamaTextMLP

Authored by ita.zaporozhets@huggingface.co on 2025-06-23 13:26:59 +00:00, committed by ita.zaporozhets@huggingface.co
parent dc09e71765
commit 91be87ec9b


@@ -54,25 +54,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
 # Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP
 class BLTMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        # Calculate intermediate_size based on actual hidden_size (not config.dim)
-        base_dim = 4 * self.hidden_size
-        intermediate_dim = int(2 * base_dim / 3)
-        self.intermediate_size = config.multiple_of * ((intermediate_dim + config.multiple_of - 1) // config.multiple_of)
+        self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
 
 
 # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
 def eager_attention_forward(
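
With this change, BLTMLP's forward pass is the standard SwiGLU-style block used by MllamaTextMLP: down_proj(act_fn(gate_proj(x)) * up_proj(x)), with intermediate_size read directly from the config. A minimal self-contained sketch of that computation in plain torch, assuming toy sizes and a SiLU activation (the real sizes and activation come from config.hidden_size, config.intermediate_size, and config.hidden_act):

    import torch
    import torch.nn as nn

    # Toy sizes for illustration only; the real model reads these from its config.
    hidden_size, intermediate_size = 8, 16
    gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
    up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
    down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    x = torch.randn(2, 4, hidden_size)
    # SwiGLU: activation on the gate path, elementwise product with the up path, then project back down.
    out = down_proj(nn.functional.silu(gate_proj(x)) * up_proj(x))
    print(out.shape)  # torch.Size([2, 4, 8])
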
@@ -259,7 +255,14 @@ class BLTTransformerLayer(nn.Module):
         self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)
-        self.mlp = BLTMLP(config=config)
+        # Create a copy of config for MLP with pre-calculated dimensions
+        mlp_config = type(config)(**config.__dict__)
+        mlp_config.hidden_size = dim
+        # Calculate intermediate_size using the same logic as BLTMLP
+        mlp_config.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)
+        self.mlp = BLTMLP(config=mlp_config)
         self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps)
         self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps)
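
The pre-calculated intermediate_size above follows the LLaMA-style FFN sizing that BLTMLP previously computed internally: take 2/3 of 4 * dim and round up to a multiple of multiple_of. A small worked example with assumed values (dim=2048, multiple_of=256; helper name and numbers are illustrative, not taken from the BLT config):

    def blt_intermediate_size(dim: int, multiple_of: int) -> int:
        # int(8 * dim / 3) == int(2/3 * (4 * dim)); the outer expression rounds up to a multiple of `multiple_of`.
        return multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)

    # Assumed example values: int(8 * 2048 / 3) = 5461, rounded up to the next multiple of 256.
    print(blt_intermediate_size(2048, 256))  # 5632
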