adding BLTRMSNorm like Llama

ita.zaporozhets@huggingface.co authored 2025-06-23 13:15:10 +00:00; committed by ita.zaporozhets@huggingface.co
parent 9400c79db7
commit 904da82c32
2 changed files with 25 additions and 12 deletions

View File

@@ -3,7 +3,7 @@ import os
 import torch
-from transformers.models.blt_wip.modeling_blt_modellike import BLTModel
+from transformers.models.blt_wip.modeling_blt import BLTModel
 from transformers.models.blt_wip.tokenization_blt import BLTTokenizer

View File

@@ -135,6 +135,26 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_rot.type_as(q), k_rot.type_as(k)
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->BLT
+class BLTRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        BLTRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

 # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
 class BLTSelfAttention(nn.Module):
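
For reference (not part of this commit), the forward pass above is standard RMS normalization: each hidden vector is scaled by the reciprocal root mean square of its last dimension, computed in float32, then multiplied by the learned weight. Below is a minimal sketch of that arithmetic, checked against torch.nn.functional.rms_norm from recent PyTorch releases; the shapes and eps value are illustrative.

import torch
import torch.nn.functional as F

def rms_norm_reference(hidden_states, weight, eps=1e-6):
    # Same arithmetic as BLTRMSNorm.forward above: normalize by the root mean
    # square over the last dimension in float32, then rescale by the weight.
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(hidden_states.dtype)

hidden = torch.randn(2, 5, 64)   # (batch, seq_len, hidden_size) -- made-up shapes
weight = torch.ones(64)          # BLTRMSNorm initializes its weight to ones
out = rms_norm_reference(hidden, weight)

# Recent PyTorch exposes the same operation as a functional; the two should agree.
ref = F.rms_norm(hidden, (64,), weight=weight, eps=1e-6)
torch.testing.assert_close(out, ref)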
@@ -241,8 +261,8 @@ class BLTTransformerLayer(nn.Module):
         self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)
         self.mlp = BLTMLP(config=config)
-        self.input_layernorm = RMSNorm(dim, eps=norm_eps)
-        self.post_attention_layernorm = RMSNorm(dim, eps=norm_eps)
+        self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps)
+        self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps)

     def forward(
         self,
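
The layer's forward pass is outside this hunk. As orientation only, the two norms are assumed to sit in the usual pre-norm (Llama-style) positions, one before self-attention and one before the MLP. The toy layer below sketches that wiring, with nn.Linear and nn.RMSNorm (recent PyTorch) standing in for BLTSelfAttention, BLTMLP, and BLTRMSNorm; it is not code from this diff.

import torch
import torch.nn as nn

class ToyPreNormLayer(nn.Module):
    # Toy stand-in for BLTTransformerLayer showing where the two norms sit;
    # the real attention and MLP blocks are replaced by linear layers.
    def __init__(self, dim=64, eps=1e-6):
        super().__init__()
        self.self_attn = nn.Linear(dim, dim)              # placeholder for BLTSelfAttention
        self.mlp = nn.Linear(dim, dim)                    # placeholder for BLTMLP
        self.input_layernorm = nn.RMSNorm(dim, eps=eps)   # BLTRMSNorm in the real model
        self.post_attention_layernorm = nn.RMSNorm(dim, eps=eps)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))     # norm -> attention
        hidden_states = residual + hidden_states                                # residual add

        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))  # norm -> MLP
        return residual + hidden_states

out = ToyPreNormLayer()(torch.randn(2, 5, 64))   # (batch, seq_len, hidden_size) -- made-up shapes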
@@ -411,13 +431,6 @@ def _prepare_patch_cross_attention_mask(
     return cross_attention_mask, full_text_row_masked_out_mask

 def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
     if max_patch_length is None:
         return patch_lengths
@@ -659,7 +672,7 @@ class BLTLocalDecoder(nn.Module):
         self.patch_embedding_projection = self._create_patch_projection(config)

-        self.norm = RMSNorm(self.dim_local_decoder, eps=self.norm_eps)
+        self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps)

         # Initialize cross attention layers only if cross attention is enabled
         self.cross_attn_layers = None
@@ -1251,7 +1264,7 @@ class BLTPatcher(BLTPreTrainedModel):
         self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim)

-        self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps)
+        self.norm = BLTRMSNorm(self.config.dim, eps=self.config.norm_eps)
         self.lm_head = nn.Linear(
             self.config.dim,
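
The patcher ends with the same pattern: one final norm over the hidden states, then a linear head projecting to vocabulary logits. A rough sketch of that tail end follows, with made-up sizes, nn.RMSNorm standing in for BLTRMSNorm, and a bias-free head assumed (the nn.Linear call is truncated in the hunk above).

import torch
import torch.nn as nn

dim, vocab_size = 64, 260                           # made-up sizes
norm = nn.RMSNorm(dim, eps=1e-6)                    # stands in for BLTRMSNorm(config.dim, eps=config.norm_eps)
lm_head = nn.Linear(dim, vocab_size, bias=False)    # bias=False is an assumption; the call is cut off above

hidden_states = torch.randn(2, 5, dim)              # (batch, seq_len, hidden_size)
logits = lm_head(norm(hidden_states))               # final norm, then project to logits
print(logits.shape)                                 # torch.Size([2, 5, 260])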