Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
adding BLTRMSNorm like Llama

commit 904da82c32 (parent 9400c79db7)
@@ -3,7 +3,7 @@ import os

 import torch

-from transformers.models.blt_wip.modeling_blt_modellike import BLTModel
+from transformers.models.blt_wip.modeling_blt import BLTModel
 from transformers.models.blt_wip.tokenization_blt import BLTTokenizer

@@ -135,6 +135,26 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

     return q_rot.type_as(q), k_rot.type_as(k)


+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->BLT
+class BLTRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        BLTRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
 # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
 class BLTSelfAttention(nn.Module):
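For context, the added BLTRMSNorm (like LlamaRMSNorm) skips LayerNorm's mean-centering and only rescales each hidden vector by its root mean square, computed in float32 for stability before casting back and applying a learned per-channel weight. A minimal sanity-check sketch, assuming the BLTRMSNorm definition from the hunk above is in scope; the sizes and random input are made up and this is not part of the diff:

import torch

# Hypothetical usage check of the added class.
hidden_size = 8
norm = BLTRMSNorm(hidden_size, eps=1e-6)

x = torch.randn(2, 4, hidden_size)  # (batch, seq_len, hidden_size)
out = norm(x)

# Reproduce the forward pass by hand: scale each vector by the reciprocal
# of its root mean square (computed in float32), then apply the weight.
h = x.to(torch.float32)
h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
expected = norm.weight * h.to(x.dtype)

assert torch.equal(out, expected)
print(norm)  # BLTRMSNorm((8,), eps=1e-06), rendered via extra_repr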
@@ -241,8 +261,8 @@ class BLTTransformerLayer(nn.Module):
         self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)

         self.mlp = BLTMLP(config=config)
-        self.input_layernorm = RMSNorm(dim, eps=norm_eps)
-        self.post_attention_layernorm = RMSNorm(dim, eps=norm_eps)
+        self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps)
+        self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps)

     def forward(
         self,
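The two norms swapped here sit in the usual pre-norm residual arrangement: one applied before self-attention, the other before the MLP. The sketch below only illustrates that placement under stated assumptions; it is not the actual BLTTransformerLayer.forward, and the attention/MLP modules are stand-ins (the real layer uses BLTSelfAttention and BLTMLP):

import torch
from torch import nn

class PreNormBlockSketch(nn.Module):
    # Illustrative pre-norm residual block; self_attn and mlp are placeholders.
    # Only the placement of the two BLTRMSNorm layers mirrors the diff above.
    def __init__(self, dim, num_heads=4, eps=1e-6):
        super().__init__()
        self.input_layernorm = BLTRMSNorm(dim, eps=eps)            # before attention
        self.post_attention_layernorm = BLTRMSNorm(dim, eps=eps)   # before the MLP
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states):
        residual = hidden_states
        h = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(h, h, h)
        hidden_states = residual + attn_out

        residual = hidden_states
        h = self.post_attention_layernorm(hidden_states)
        return residual + self.mlp(h)

block = PreNormBlockSketch(dim=64)
print(block(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])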
@@ -411,13 +431,6 @@ def _prepare_patch_cross_attention_mask(

     return cross_attention_mask, full_text_row_masked_out_mask

-
-
-
-
-
-
-
 def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
     if max_patch_length is None:
         return patch_lengths
@@ -659,7 +672,7 @@ class BLTLocalDecoder(nn.Module):

         self.patch_embedding_projection = self._create_patch_projection(config)

-        self.norm = RMSNorm(self.dim_local_decoder, eps=self.norm_eps)
+        self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps)

         # Initialize cross attention layers only if cross attention is enabled
         self.cross_attn_layers = None
@@ -1251,7 +1264,7 @@ class BLTPatcher(BLTPreTrainedModel):

         self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim)

-        self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps)
+        self.norm = BLTRMSNorm(self.config.dim, eps=self.config.norm_eps)

         self.lm_head = nn.Linear(
             self.config.dim,
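The same one-line swap appears in BLTTransformerLayer, BLTLocalDecoder, and BLTPatcher. As a follow-up after applying the hunks, a small hedged check that no bare RMSNorm call sites were left behind; the file path is an assumption inferred from the imports above, not taken from the diff:

from pathlib import Path

# Assumed location of the edited modeling file.
src = Path("src/transformers/models/blt_wip/modeling_blt.py").read_text()

# Lines that still construct a plain RMSNorm (a space right before the call
# excludes the renamed BLTRMSNorm) should be empty after this commit.
leftovers = [line.strip() for line in src.splitlines() if " RMSNorm(" in line]
print(leftovers)  # expected: []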