BLTTransformerLayer config

commit 9d318707e4 (parent d2e6902460)
ita.zaporozhets@huggingface.co authored and committed on 2025-06-23 15:06:42 +00:00
2 changed files with 125 additions and 54 deletions

configuration_blt.py

@@ -37,6 +37,44 @@ class PatchingModeEnum(str, Enum):
     byte = "byte"
 
 
+class TransformersLayerConfig:
+    """
+    Configuration class for BLT Transformer layers, providing all necessary parameters
+    for attention, MLP, and normalization components.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        num_key_value_heads: int,
+        head_dim: int,
+        intermediate_size: int,
+        norm_eps: float,
+        dropout: float,
+        max_position_embeddings: int,
+        rope_theta: float,
+        rope_scaling: dict,
+        hidden_act: str = "silu",
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.intermediate_size = intermediate_size
+        self.norm_eps = norm_eps
+        self.dropout = dropout
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.hidden_act = hidden_act
+        # Add any additional kwargs as attributes
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+
 class BLTPatcherConfig(PretrainedConfig):
     r"""
     Configuration class for the BLT Patcher/Entropy model component.
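Note (illustration only, not part of the diff): TransformersLayerConfig is a plain container, so every named argument becomes an attribute and any extra keyword argument is attached via the **kwargs loop. The values below are made up for the example.

    layer_config = TransformersLayerConfig(
        hidden_size=512,
        num_attention_heads=8,
        num_key_value_heads=8,
        head_dim=64,                  # 512 / 8
        intermediate_size=2048,
        norm_eps=1e-5,
        dropout=0.0,
        max_position_embeddings=4096,
        rope_theta=10000.0,
        rope_scaling={"rope_type": "default"},
        sliding_window=512,           # extra kwarg, stored as an attribute via setattr
    )
    assert layer_config.head_dim == 64 and layer_config.sliding_window == 512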
@@ -578,6 +616,63 @@ class BLTConfig(PretrainedConfig):
         # Note: Each component uses its own hidden dimension, not the main dim
         self.intermediate_size = None  # Will be calculated per component
 
+        # layer configurations as dictionaries (needed to be JSON serializable!)
+        self._encoder_layer_config_dict = {
+            "hidden_size": self.dim_local_encoder,
+            "num_attention_heads": self.n_heads_local_encoder,
+            "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_encoder,
+            "head_dim": self.dim_local_encoder // self.n_heads_local_encoder,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_encoder / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._decoder_layer_config_dict = {
+            "hidden_size": self.dim_local_decoder,
+            "num_attention_heads": self.n_heads_local_decoder,
+            "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_decoder,
+            "head_dim": self.dim_local_decoder // self.n_heads_local_decoder,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_decoder / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._global_layer_config_dict = {
+            "hidden_size": self.dim_global,
+            "num_attention_heads": self.n_heads_global,
+            "num_key_value_heads": getattr(self, 'n_kv_heads_global', None) or self.n_heads_global,
+            "head_dim": self.dim_global // self.n_heads_global,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_global / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._patcher_layer_config_dict = {
+            "hidden_size": self.patcher_config.dim,
+            "num_attention_heads": self.patcher_config.n_heads,
+            "num_key_value_heads": getattr(self.patcher_config, 'n_kv_heads', None) or self.patcher_config.n_heads,
+            "head_dim": self.patcher_config.dim // self.patcher_config.n_heads,
+            "intermediate_size": self.patcher_config.multiple_of * ((int(8 * self.patcher_config.dim / 3) + self.patcher_config.multiple_of - 1) // self.patcher_config.multiple_of),
+            "norm_eps": self.patcher_config.norm_eps,
+            "dropout": self.patcher_config.dropout,
+            "max_position_embeddings": self.patcher_config.max_seqlen,
+            "rope_theta": self.patcher_config.rope_theta,
+            "rope_scaling": self.patcher_config.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
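Note (worked example, not part of the diff): the "intermediate_size" expression repeated in all four dictionaries is the usual SwiGLU sizing rule, take 8/3 of the hidden size and round it up to the nearest multiple of multiple_of. With example numbers:

    def swiglu_intermediate_size(dim: int, multiple_of: int) -> int:
        # Round int(8 * dim / 3) up to the nearest multiple of `multiple_of`.
        return multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)

    # dim=1024, multiple_of=256: int(8 * 1024 / 3) = 2730, rounded up to 2816
    assert swiglu_intermediate_size(1024, 256) == 2816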
@@ -585,6 +680,21 @@ class BLTConfig(PretrainedConfig):
             **kwargs,
         )
 
+    @property
+    def encoder_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._encoder_layer_config_dict)
+
+    @property
+    def decoder_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._decoder_layer_config_dict)
+
+    @property
+    def global_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._global_layer_config_dict)
+
+    @property
+    def patcher_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._patcher_layer_config_dict)
+
     @property
     def encoder_dim_token_emb(self):
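Note (usage sketch, not part of the diff, assuming an already-constructed BLTConfig instance named config): the raw dictionaries stay on the config so they serialize with everything else, while each property rebuilds a TransformersLayerConfig object on access.

    enc_cfg = config.encoder_layer_config    # fresh TransformersLayerConfig on every access
    assert enc_cfg.hidden_size == config.dim_local_encoder
    serialized = config.to_dict()            # works because the *_layer_config_dict values are plain JSON types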
@@ -648,6 +758,5 @@ class BLTConfig(PretrainedConfig):
         else:  # DISABLED
             return 1.0
 
-__all__ = ["BLTConfig", "BLTPatcherConfig", "InitStdFactor", "PatchingModeEnum"]
+__all__ = ["BLTConfig", "BLTPatcherConfig", "TransformersLayerConfig", "InitStdFactor", "PatchingModeEnum"]

modeling_blt.py

@@ -31,6 +31,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from .configuration_blt import (
     BLTConfig,
     PatchingModeEnum,
+    TransformersLayerConfig,
 )
 
 if is_torch_flex_attn_available():
@@ -153,7 +154,7 @@ class BLTRMSNorm(nn.Module):
 # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
 class BLTSelfAttention(nn.Module):
-    def __init__(self, config: BLTConfig, layer_idx: int):
+    def __init__(self, config: TransformersLayerConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.num_heads = config.num_attention_heads
@@ -233,9 +234,10 @@ class BLTSelfAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
-# Copied from transformers.models.llama.modeling_mllama.MllamaSelfAttentionDecoderLayer
+# Copied from transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer
 class BLTTransformerLayer(nn.Module):
-    def __init__(self, config: BLTConfig, layer_idx: int):
+    def __init__(self, config: TransformersLayerConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
@@ -480,7 +482,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) ->
 class BLTRotaryEmbedding(nn.Module):
-    def __init__(self, config: BLTConfig, device=None):
+    def __init__(self, config: TransformersLayerConfig, device=None):
         super().__init__()
         self.rope_type = config.rope_scaling["rope_type"]
         self.max_seq_len_cached = config.max_position_embeddings
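Note (assumption, not part of the diff): BLTRotaryEmbedding now reads rope_scaling["rope_type"] and max_position_embeddings from the per-layer config, so the rope_scaling dict routed through TransformersLayerConfig needs at least a "rope_type" key. A minimal value following the usual transformers RoPE convention would be:

    rope_scaling = {"rope_type": "default"}   # plain RoPE, no scaling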
@@ -528,19 +530,9 @@ class BLTLocalEncoder(nn.Module):
         self.norm_eps = config.norm_eps
         self.sliding_window = config.sliding_window
 
-        # Set up config for layers with proper dimensions
-        encoder_config = config
-        encoder_config.hidden_size = self.dim_local_encoder
-        encoder_config.num_attention_heads = self.n_heads_local_encoder
-        encoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_encoder
-        encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder
-        encoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_encoder / 3) + config.multiple_of - 1) // config.multiple_of)
-        self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)])
-
-        # Set up config for rotary embedding
-        encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen
-        self.rotary_emb = BLTRotaryEmbedding(config=encoder_config)
+        self.layers = nn.ModuleList([BLTTransformerLayer(config.encoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)])
+        self.rotary_emb = BLTRotaryEmbedding(config=config.encoder_layer_config)
 
         self.token_embedding_projection = (
             nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False)
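Note (background illustration, not part of the diff): the removed pattern encoder_config = config only rebinds a name, it does not copy, so the per-component overrides were written onto the one shared config object. A minimal sketch of that pitfall with a stand-in class:

    class Cfg:
        hidden_size = 1024

    config = Cfg()
    encoder_config = config             # same object, second name
    encoder_config.hidden_size = 512    # also visible through `config`
    assert config.hidden_size == 512

Building each BLTTransformerLayer from config.encoder_layer_config instead hands every component its own TransformersLayerConfig instance.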
@@ -552,7 +544,6 @@ class BLTLocalEncoder(nn.Module):
         self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder)
 
         # Initialize cross attention layers only if cross attention is enabled
-        self.cross_attn_layers = None
         if self.cross_attn_encoder and self.cross_attn_nheads is not None:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -680,19 +671,9 @@ class BLTLocalDecoder(nn.Module):
         self.cross_attn_k = config.cross_attn_k
         self.sliding_window = config.sliding_window
 
-        # Set up config for layers with proper dimensions
-        decoder_config = config
-        decoder_config.hidden_size = self.dim_local_decoder
-        decoder_config.num_attention_heads = self.n_heads_local_decoder
-        decoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_decoder
-        decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder
-        decoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_decoder / 3) + config.multiple_of - 1) // config.multiple_of)
-        self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)])
-
-        decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen
-        self.rotary_emb = BLTRotaryEmbedding(config=decoder_config)
+        self.layers = nn.ModuleList([BLTTransformerLayer(config.decoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)])
+        self.rotary_emb = BLTRotaryEmbedding(config=config.decoder_layer_config)
 
         self.token_embedding_projection = (
             nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False)
@@ -704,7 +685,6 @@ class BLTLocalDecoder(nn.Module):
         self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps)
 
         # Initialize cross attention layers only if cross attention is enabled
-        self.cross_attn_layers = None
         if self.cross_attn_decoder and self.cross_attn_nheads is not None:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -789,7 +769,7 @@ class BLTLocalDecoder(nn.Module):
         logits = self.lm_head(self.norm(hidden_states))
         return logits, cache
 
 
 # Modified from transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention
 class BLTCrossAttention(nn.Module):
     """Cross-attention module for BLT, following transformers style"""
@@ -898,25 +878,16 @@ class BLTGlobalTransformer(nn.Module):
     def __init__(self, config):
         super().__init__()
 
         # Extract config values to instance attributes
         self.dim_global = config.dim_global
         self.n_heads_global = config.n_heads_global
         self.n_layers_global = config.n_layers_global
         self.dropout = config.dropout
 
-        # Set up config for layers with proper dimensions
-        global_config = config
-        global_config.hidden_size = self.dim_global
-        global_config.num_attention_heads = self.n_heads_global
-        global_config.num_key_value_heads = getattr(config, 'n_kv_heads_global', None) or self.n_heads_global
-        global_config.head_dim = self.dim_global // self.n_heads_global
-        global_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_global / 3) + config.multiple_of - 1) // config.multiple_of)
-
         self.layers = nn.ModuleList()
         for layer_idx in range(self.n_layers_global):
-            self.layers.append(BLTTransformerLayer(global_config, layer_idx))
+            self.layers.append(BLTTransformerLayer(config.global_layer_config, layer_idx))
 
-        self.rotary_emb = BLTRotaryEmbedding(config=global_config)
+        self.rotary_emb = BLTRotaryEmbedding(config=config.global_layer_config)
 
         self.token_embedding_projection = None
         if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global:
@@ -1292,17 +1263,8 @@ class BLTPatcher(BLTPreTrainedModel):
         self.rotary_emb = BLTRotaryEmbedding(config=self.config)
 
         self.layers = nn.ModuleList()
-
-        # Set up config for layers with proper dimensions
-        patcher_config = self.config
-        patcher_config.hidden_size = self.config.dim
-        patcher_config.num_attention_heads = self.config.n_heads
-        patcher_config.num_key_value_heads = getattr(self.config, 'n_kv_heads', None) or self.config.n_heads
-        patcher_config.head_dim = self.config.dim // self.config.n_heads
-        patcher_config.intermediate_size = self.config.multiple_of * ((int(8 * self.config.dim / 3) + self.config.multiple_of - 1) // self.config.multiple_of)
         for layer_idx in range(self.config.n_layers):
-            self.layers.append(BLTTransformerLayer(patcher_config, layer_idx))
+            self.layers.append(BLTTransformerLayer(config.patcher_layer_config, layer_idx))
 
         self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim)