mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
BLTTransformerLayer config
This commit is contained in:
parent d2e6902460
commit 9d318707e4
@@ -37,6 +37,44 @@ class PatchingModeEnum(str, Enum):
     byte = "byte"
 
 
+class TransformersLayerConfig:
+    """
+    Configuration class for BLT Transformer layers, providing all necessary parameters
+    for attention, MLP, and normalization components.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        num_key_value_heads: int,
+        head_dim: int,
+        intermediate_size: int,
+        norm_eps: float,
+        dropout: float,
+        max_position_embeddings: int,
+        rope_theta: float,
+        rope_scaling: dict,
+        hidden_act: str = "silu",
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.intermediate_size = intermediate_size
+        self.norm_eps = norm_eps
+        self.dropout = dropout
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.hidden_act = hidden_act
+
+        # Add any additional kwargs as attributes
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+
 class BLTPatcherConfig(PretrainedConfig):
     r"""
     Configuration class for the BLT Patcher/Entropy model component.
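As a rough illustration of how the new plain-Python config object is meant to be used (the values below are invented for this sketch, not taken from the commit), extra keyword arguments simply become attributes:

# Hypothetical standalone usage; all numbers are illustrative only.
layer_config = TransformersLayerConfig(
    hidden_size=512,
    num_attention_heads=8,
    num_key_value_heads=8,
    head_dim=64,
    intermediate_size=1408,
    norm_eps=1e-5,
    dropout=0.0,
    max_position_embeddings=4096,
    rope_theta=10000.0,
    rope_scaling={"rope_type": "default"},
    sliding_window=None,  # not a named parameter; attached via the **kwargs loop
)
assert layer_config.head_dim == 64 and layer_config.sliding_window is None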
@@ -578,6 +616,63 @@ class BLTConfig(PretrainedConfig):
         # Note: Each component uses its own hidden dimension, not the main dim
         self.intermediate_size = None # Will be calculated per component
 
+        # layer configurations as dictionaries (needed to be JSON serializable!)
+        self._encoder_layer_config_dict = {
+            "hidden_size": self.dim_local_encoder,
+            "num_attention_heads": self.n_heads_local_encoder,
+            "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_encoder,
+            "head_dim": self.dim_local_encoder // self.n_heads_local_encoder,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_encoder / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._decoder_layer_config_dict = {
+            "hidden_size": self.dim_local_decoder,
+            "num_attention_heads": self.n_heads_local_decoder,
+            "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_decoder,
+            "head_dim": self.dim_local_decoder // self.n_heads_local_decoder,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_decoder / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._global_layer_config_dict = {
+            "hidden_size": self.dim_global,
+            "num_attention_heads": self.n_heads_global,
+            "num_key_value_heads": getattr(self, 'n_kv_heads_global', None) or self.n_heads_global,
+            "head_dim": self.dim_global // self.n_heads_global,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_global / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._patcher_layer_config_dict = {
+            "hidden_size": self.patcher_config.dim,
+            "num_attention_heads": self.patcher_config.n_heads,
+            "num_key_value_heads": getattr(self.patcher_config, 'n_kv_heads', None) or self.patcher_config.n_heads,
+            "head_dim": self.patcher_config.dim // self.patcher_config.n_heads,
+            "intermediate_size": self.patcher_config.multiple_of * ((int(8 * self.patcher_config.dim / 3) + self.patcher_config.multiple_of - 1) // self.patcher_config.multiple_of),
+            "norm_eps": self.patcher_config.norm_eps,
+            "dropout": self.patcher_config.dropout,
+            "max_position_embeddings": self.patcher_config.max_seqlen,
+            "rope_theta": self.patcher_config.rope_theta,
+            "rope_scaling": self.patcher_config.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
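Every intermediate_size entry above uses the same arithmetic: take 8/3 of the component's hidden dimension and round up to the nearest multiple of multiple_of. A minimal sketch of that calculation (the helper name is ours, not part of the commit):

def _ffn_intermediate_size(hidden_size: int, multiple_of: int) -> int:
    # 8/3 * hidden_size, rounded up to a multiple of `multiple_of`
    return multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of)

# e.g. hidden_size=768, multiple_of=256: int(8 * 768 / 3) = 2048, already a multiple of 256
assert _ffn_intermediate_size(768, 256) == 2048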
@@ -585,6 +680,21 @@ class BLTConfig(PretrainedConfig):
             **kwargs,
         )
 
+    @property
+    def encoder_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._encoder_layer_config_dict)
+
+    @property
+    def decoder_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._decoder_layer_config_dict)
+
+    @property
+    def global_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._global_layer_config_dict)
+
+    @property
+    def patcher_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._patcher_layer_config_dict)
+
     @property
     def encoder_dim_token_emb(self):
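Storing plain dicts and rebuilding typed objects through properties is what keeps BLTConfig JSON-serializable. A hedged sketch of the expected behaviour, assuming config is an already-constructed BLTConfig and that PretrainedConfig.to_dict() retains underscore-prefixed attributes as it does for ordinary fields:

enc_cfg = config.encoder_layer_config  # a fresh TransformersLayerConfig on each access
assert isinstance(enc_cfg, TransformersLayerConfig)
assert enc_cfg.hidden_size == config.dim_local_encoder

# Only the dict form ends up in the serialized config; the typed objects are rebuilt on demand.
assert "_encoder_layer_config_dict" in config.to_dict()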
@@ -648,6 +758,5 @@ class BLTConfig(PretrainedConfig):
         else: # DISABLED
             return 1.0
 
 
-__all__ = ["BLTConfig", "BLTPatcherConfig", "InitStdFactor", "PatchingModeEnum"]
+__all__ = ["BLTConfig", "BLTPatcherConfig", "TransformersLayerConfig", "InitStdFactor", "PatchingModeEnum"]
@@ -31,6 +31,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from .configuration_blt import (
     BLTConfig,
     PatchingModeEnum,
+    TransformersLayerConfig,
 )
 
 if is_torch_flex_attn_available():
@@ -153,7 +154,7 @@ class BLTRMSNorm(nn.Module):
 
 # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
 class BLTSelfAttention(nn.Module):
-    def __init__(self, config: BLTConfig, layer_idx: int):
+    def __init__(self, config: TransformersLayerConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.num_heads = config.num_attention_heads
@@ -233,9 +234,10 @@ class BLTSelfAttention(nn.Module):
 
         return attn_output, attn_weights, past_key_value
 
-# Copied from transformers.models.llama.modeling_mllama.MllamaSelfAttentionDecoderLayer
+
+# Copied from transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer
 class BLTTransformerLayer(nn.Module):
-    def __init__(self, config: BLTConfig, layer_idx: int):
+    def __init__(self, config: TransformersLayerConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
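With these signature changes, each transformer layer and its rotary embedding are built from a per-component TransformersLayerConfig rather than the full BLTConfig. A hedged sketch of the intended wiring, reusing the illustrative layer_config from the earlier sketch (the import path is an assumption, not part of this commit):

# from transformers.models.blt.modeling_blt import BLTTransformerLayer, BLTRotaryEmbedding  # assumed path

layers = [BLTTransformerLayer(layer_config, layer_idx=i) for i in range(4)]
rotary_emb = BLTRotaryEmbedding(config=layer_config)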
@@ -480,7 +482,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) ->
 
 
 class BLTRotaryEmbedding(nn.Module):
-    def __init__(self, config: BLTConfig, device=None):
+    def __init__(self, config: TransformersLayerConfig, device=None):
         super().__init__()
         self.rope_type = config.rope_scaling["rope_type"]
         self.max_seq_len_cached = config.max_position_embeddings
@@ -528,19 +530,9 @@ class BLTLocalEncoder(nn.Module):
         self.norm_eps = config.norm_eps
         self.sliding_window = config.sliding_window
 
-        # Set up config for layers with proper dimensions
-        encoder_config = config
-        encoder_config.hidden_size = self.dim_local_encoder
-        encoder_config.num_attention_heads = self.n_heads_local_encoder
-        encoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_encoder
-        encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder
-        encoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_encoder / 3) + config.multiple_of - 1) // config.multiple_of)
+        self.layers = nn.ModuleList([BLTTransformerLayer(config.encoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)])
 
-        self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)])
-
-        # Set up config for rotary embedding
-        encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen
-        self.rotary_emb = BLTRotaryEmbedding(config=encoder_config)
+        self.rotary_emb = BLTRotaryEmbedding(config=config.encoder_layer_config)
 
         self.token_embedding_projection = (
             nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False)
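Note that the removed pattern assigned encoder_config = config, which aliases the shared BLTConfig rather than copying it, so every component's "setup" mutated the same object. A minimal sketch of the hazard the per-component configs avoid (the class below is an illustrative stand-in, not BLT code):

class SharedConfig:
    hidden_size = 0

shared = SharedConfig()
encoder_config = shared            # aliasing, as in the removed code
encoder_config.hidden_size = 512
decoder_config = shared
decoder_config.hidden_size = 256
assert shared.hidden_size == 256   # the encoder's value was silently overwritten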
@@ -552,7 +544,6 @@ class BLTLocalEncoder(nn.Module):
 
         self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder)
 
-        # Initialize cross attention layers only if cross attention is enabled
         self.cross_attn_layers = None
         if self.cross_attn_encoder and self.cross_attn_nheads is not None:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -680,19 +671,9 @@ class BLTLocalDecoder(nn.Module):
         self.cross_attn_k = config.cross_attn_k
         self.sliding_window = config.sliding_window
 
-        # Set up config for layers with proper dimensions
-        decoder_config = config
-        decoder_config.hidden_size = self.dim_local_decoder
-        decoder_config.num_attention_heads = self.n_heads_local_decoder
-        decoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_decoder
-        decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder
-        decoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_decoder / 3) + config.multiple_of - 1) // config.multiple_of)
+        self.layers = nn.ModuleList([BLTTransformerLayer(config.decoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)])
 
-        self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)])
-
-        decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen
-
-        self.rotary_emb = BLTRotaryEmbedding(config=decoder_config)
+        self.rotary_emb = BLTRotaryEmbedding(config=config.decoder_layer_config)
 
         self.token_embedding_projection = (
             nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False)
@@ -704,7 +685,6 @@ class BLTLocalDecoder(nn.Module):
 
         self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps)
 
-        # Initialize cross attention layers only if cross attention is enabled
         self.cross_attn_layers = None
         if self.cross_attn_decoder and self.cross_attn_nheads is not None:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -789,7 +769,7 @@ class BLTLocalDecoder(nn.Module):
         logits = self.lm_head(self.norm(hidden_states))
         return logits, cache
 
 # Modified from transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention
 
 class BLTCrossAttention(nn.Module):
     """Cross-attention module for BLT, following transformers style"""
 
@@ -898,25 +878,16 @@ class BLTGlobalTransformer(nn.Module):
     def __init__(self, config):
         super().__init__()
 
         # Extract config values to instance attributes
         self.dim_global = config.dim_global
         self.n_heads_global = config.n_heads_global
         self.n_layers_global = config.n_layers_global
         self.dropout = config.dropout
 
-        # Set up config for layers with proper dimensions
-        global_config = config
-        global_config.hidden_size = self.dim_global
-        global_config.num_attention_heads = self.n_heads_global
-        global_config.num_key_value_heads = getattr(config, 'n_kv_heads_global', None) or self.n_heads_global
-        global_config.head_dim = self.dim_global // self.n_heads_global
-        global_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_global / 3) + config.multiple_of - 1) // config.multiple_of)
 
         self.layers = nn.ModuleList()
         for layer_idx in range(self.n_layers_global):
-            self.layers.append(BLTTransformerLayer(global_config, layer_idx))
+            self.layers.append(BLTTransformerLayer(config.global_layer_config, layer_idx))
 
-        self.rotary_emb = BLTRotaryEmbedding(config=global_config)
+        self.rotary_emb = BLTRotaryEmbedding(config=config.global_layer_config)
 
         self.token_embedding_projection = None
         if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global:
@@ -1292,17 +1263,8 @@ class BLTPatcher(BLTPreTrainedModel):
         self.rotary_emb = BLTRotaryEmbedding(config=self.config)
 
         self.layers = nn.ModuleList()
-        # Set up config for layers with proper dimensions
-        patcher_config = self.config
-        patcher_config.hidden_size = self.config.dim
-        patcher_config.num_attention_heads = self.config.n_heads
-        patcher_config.num_key_value_heads = getattr(self.config, 'n_kv_heads', None) or self.config.n_heads
-        patcher_config.head_dim = self.config.dim // self.config.n_heads
-        patcher_config.intermediate_size = self.config.multiple_of * ((int(8 * self.config.dim / 3) + self.config.multiple_of - 1) // self.config.multiple_of)
-
         for layer_idx in range(self.config.n_layers):
-            self.layers.append(BLTTransformerLayer(patcher_config, layer_idx))
-
+            self.layers.append(BLTTransformerLayer(config.patcher_layer_config, layer_idx))
 
         self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim)
 