Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 11:11:05 +06:00
BLTTransformerLayer config
commit 9d318707e4
parent d2e6902460
@@ -37,6 +37,44 @@ class PatchingModeEnum(str, Enum):
     byte = "byte"
 
 
+class TransformersLayerConfig:
+    """
+    Configuration class for BLT Transformer layers, providing all necessary parameters
+    for attention, MLP, and normalization components.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        num_key_value_heads: int,
+        head_dim: int,
+        intermediate_size: int,
+        norm_eps: float,
+        dropout: float,
+        max_position_embeddings: int,
+        rope_theta: float,
+        rope_scaling: dict,
+        hidden_act: str = "silu",
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.intermediate_size = intermediate_size
+        self.norm_eps = norm_eps
+        self.dropout = dropout
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.hidden_act = hidden_act
+
+        # Add any additional kwargs as attributes
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+
 class BLTPatcherConfig(PretrainedConfig):
     r"""
     Configuration class for the BLT Patcher/Entropy model component.
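Note: the new class is a plain attribute container, and any extra keyword arguments are attached via setattr. A minimal sketch with illustrative values (none of these numbers come from this commit; `sliding_window` here is just an example of an extra kwarg):

    layer_cfg = TransformersLayerConfig(
        hidden_size=1024,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=64,
        intermediate_size=2816,
        norm_eps=1e-5,
        dropout=0.0,
        max_position_embeddings=4096,
        rope_theta=10000.0,
        rope_scaling={"rope_type": "default"},
        sliding_window=512,  # not a declared parameter: stored as an attribute via **kwargs
    )
    assert layer_cfg.hidden_act == "silu"   # default value
    assert layer_cfg.sliding_window == 512  # extra kwarg preserved as an attribute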
@@ -578,6 +616,63 @@ class BLTConfig(PretrainedConfig):
         # Note: Each component uses its own hidden dimension, not the main dim
         self.intermediate_size = None # Will be calculated per component
 
+        # layer configurations as dictionaries (needed to be JSON serializable!)
+        self._encoder_layer_config_dict = {
+            "hidden_size": self.dim_local_encoder,
+            "num_attention_heads": self.n_heads_local_encoder,
+            "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_encoder,
+            "head_dim": self.dim_local_encoder // self.n_heads_local_encoder,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_encoder / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._decoder_layer_config_dict = {
+            "hidden_size": self.dim_local_decoder,
+            "num_attention_heads": self.n_heads_local_decoder,
+            "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_decoder,
+            "head_dim": self.dim_local_decoder // self.n_heads_local_decoder,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_decoder / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._global_layer_config_dict = {
+            "hidden_size": self.dim_global,
+            "num_attention_heads": self.n_heads_global,
+            "num_key_value_heads": getattr(self, 'n_kv_heads_global', None) or self.n_heads_global,
+            "head_dim": self.dim_global // self.n_heads_global,
+            "intermediate_size": self.multiple_of * ((int(8 * self.dim_global / 3) + self.multiple_of - 1) // self.multiple_of),
+            "norm_eps": self.norm_eps,
+            "dropout": self.dropout,
+            "max_position_embeddings": self.max_seqlen,
+            "rope_theta": self.rope_theta,
+            "rope_scaling": self.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
+        self._patcher_layer_config_dict = {
+            "hidden_size": self.patcher_config.dim,
+            "num_attention_heads": self.patcher_config.n_heads,
+            "num_key_value_heads": getattr(self.patcher_config, 'n_kv_heads', None) or self.patcher_config.n_heads,
+            "head_dim": self.patcher_config.dim // self.patcher_config.n_heads,
+            "intermediate_size": self.patcher_config.multiple_of * ((int(8 * self.patcher_config.dim / 3) + self.patcher_config.multiple_of - 1) // self.patcher_config.multiple_of),
+            "norm_eps": self.patcher_config.norm_eps,
+            "dropout": self.patcher_config.dropout,
+            "max_position_embeddings": self.patcher_config.max_seqlen,
+            "rope_theta": self.patcher_config.rope_theta,
+            "rope_scaling": self.patcher_config.rope_scaling,
+            "hidden_act": self.hidden_act,
+        }
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
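Note on the "intermediate_size" entries: each dict applies the common SwiGLU sizing rule, roughly 8/3 of the component's hidden size rounded up to a multiple of `multiple_of`. A quick worked example with illustrative numbers (not values from this commit):

    dim, multiple_of = 1024, 256
    intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)
    # int(8 * 1024 / 3) = 2730; (2730 + 255) // 256 = 11; 11 * 256 = 2816
    assert intermediate_size == 2816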
@@ -585,6 +680,21 @@ class BLTConfig(PretrainedConfig):
             **kwargs,
         )
 
+    @property
+    def encoder_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._encoder_layer_config_dict)
+
+    @property
+    def decoder_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._decoder_layer_config_dict)
+
+    @property
+    def global_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._global_layer_config_dict)
+
+    @property
+    def patcher_layer_config(self) -> TransformersLayerConfig:
+        return TransformersLayerConfig(**self._patcher_layer_config_dict)
+
     @property
     def encoder_dim_token_emb(self):
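Note: only the plain dicts live on BLTConfig (keeping the config JSON-serializable, as the comment above says), and each property rebuilds a typed TransformersLayerConfig on access. A hedged usage sketch, assuming `config` is an already-constructed BLTConfig:

    enc_cfg = config.encoder_layer_config             # fresh TransformersLayerConfig on each access
    layer = BLTTransformerLayer(enc_cfg, layer_idx=0)
    print(enc_cfg.hidden_size, enc_cfg.head_dim, enc_cfg.intermediate_size)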
@@ -648,6 +758,5 @@ class BLTConfig(PretrainedConfig):
         else: # DISABLED
             return 1.0
 
 
-__all__ = ["BLTConfig", "BLTPatcherConfig", "InitStdFactor", "PatchingModeEnum"]
+__all__ = ["BLTConfig", "BLTPatcherConfig", "TransformersLayerConfig", "InitStdFactor", "PatchingModeEnum"]
@@ -31,6 +31,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from .configuration_blt import (
     BLTConfig,
     PatchingModeEnum,
+    TransformersLayerConfig,
 )
 
 if is_torch_flex_attn_available():
@@ -153,7 +154,7 @@ class BLTRMSNorm(nn.Module):
 
 
 # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
 class BLTSelfAttention(nn.Module):
-    def __init__(self, config: BLTConfig, layer_idx: int):
+    def __init__(self, config: TransformersLayerConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.num_heads = config.num_attention_heads
@@ -233,9 +234,10 @@ class BLTSelfAttention(nn.Module):
 
         return attn_output, attn_weights, past_key_value
 
 
-# Copied from transformers.models.llama.modeling_mllama.MllamaSelfAttentionDecoderLayer
+# Copied from transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer
 class BLTTransformerLayer(nn.Module):
-    def __init__(self, config: BLTConfig, layer_idx: int):
+    def __init__(self, config: TransformersLayerConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
@@ -480,7 +482,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) ->
 
 
 class BLTRotaryEmbedding(nn.Module):
-    def __init__(self, config: BLTConfig, device=None):
+    def __init__(self, config: TransformersLayerConfig, device=None):
         super().__init__()
         self.rope_type = config.rope_scaling["rope_type"]
         self.max_seq_len_cached = config.max_position_embeddings
@@ -528,19 +530,9 @@ class BLTLocalEncoder(nn.Module):
         self.norm_eps = config.norm_eps
         self.sliding_window = config.sliding_window
 
-        # Set up config for layers with proper dimensions
-        encoder_config = config
-        encoder_config.hidden_size = self.dim_local_encoder
-        encoder_config.num_attention_heads = self.n_heads_local_encoder
-        encoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_encoder
-        encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder
-        encoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_encoder / 3) + config.multiple_of - 1) // config.multiple_of)
-
-        self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)])
-
-        # Set up config for rotary embedding
-        encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen
-        self.rotary_emb = BLTRotaryEmbedding(config=encoder_config)
+        self.layers = nn.ModuleList([BLTTransformerLayer(config.encoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)])
+        self.rotary_emb = BLTRotaryEmbedding(config=config.encoder_layer_config)
 
         self.token_embedding_projection = (
             nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False)
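Note on the removed pattern: `encoder_config = config` does not copy anything; it binds a second name to the same config object, so every `encoder_config.xxx = ...` assignment was visible to every other component sharing that config. A small sketch of the aliasing pitfall (hypothetical `Cfg` class, not code from this commit):

    class Cfg:
        def __init__(self):
            self.hidden_size = 2048

    cfg = Cfg()
    encoder_cfg = cfg                # no copy: both names point at the same object
    encoder_cfg.hidden_size = 1024   # intended for the encoder only...
    print(cfg.hidden_size)           # 1024 -- the shared config was silently overwritten

Building a dedicated TransformersLayerConfig per component (the encoder_layer_config, decoder_layer_config, global_layer_config and patcher_layer_config properties above) avoids mutating the shared BLTConfig.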
@@ -552,7 +544,6 @@ class BLTLocalEncoder(nn.Module):
 
         self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder)
 
-        # Initialize cross attention layers only if cross attention is enabled
         self.cross_attn_layers = None
         if self.cross_attn_encoder and self.cross_attn_nheads is not None:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -680,19 +671,9 @@ class BLTLocalDecoder(nn.Module):
         self.cross_attn_k = config.cross_attn_k
         self.sliding_window = config.sliding_window
 
-        # Set up config for layers with proper dimensions
-        decoder_config = config
-        decoder_config.hidden_size = self.dim_local_decoder
-        decoder_config.num_attention_heads = self.n_heads_local_decoder
-        decoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_decoder
-        decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder
-        decoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_decoder / 3) + config.multiple_of - 1) // config.multiple_of)
-
-        self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)])
-
-        decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen
-        self.rotary_emb = BLTRotaryEmbedding(config=decoder_config)
+        self.layers = nn.ModuleList([BLTTransformerLayer(config.decoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)])
+        self.rotary_emb = BLTRotaryEmbedding(config=config.decoder_layer_config)
 
         self.token_embedding_projection = (
             nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False)
@@ -704,7 +685,6 @@ class BLTLocalDecoder(nn.Module):
 
         self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps)
 
-        # Initialize cross attention layers only if cross attention is enabled
         self.cross_attn_layers = None
         if self.cross_attn_decoder and self.cross_attn_nheads is not None:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -789,7 +769,7 @@ class BLTLocalDecoder(nn.Module):
         logits = self.lm_head(self.norm(hidden_states))
         return logits, cache
 
 
-# Modified from transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention
 class BLTCrossAttention(nn.Module):
     """Cross-attention module for BLT, following transformers style"""
 
@@ -898,25 +878,16 @@ class BLTGlobalTransformer(nn.Module):
     def __init__(self, config):
         super().__init__()
 
-        # Extract config values to instance attributes
         self.dim_global = config.dim_global
         self.n_heads_global = config.n_heads_global
         self.n_layers_global = config.n_layers_global
         self.dropout = config.dropout
 
-        # Set up config for layers with proper dimensions
-        global_config = config
-        global_config.hidden_size = self.dim_global
-        global_config.num_attention_heads = self.n_heads_global
-        global_config.num_key_value_heads = getattr(config, 'n_kv_heads_global', None) or self.n_heads_global
-        global_config.head_dim = self.dim_global // self.n_heads_global
-        global_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_global / 3) + config.multiple_of - 1) // config.multiple_of)
-
         self.layers = nn.ModuleList()
         for layer_idx in range(self.n_layers_global):
-            self.layers.append(BLTTransformerLayer(global_config, layer_idx))
+            self.layers.append(BLTTransformerLayer(config.global_layer_config, layer_idx))
 
-        self.rotary_emb = BLTRotaryEmbedding(config=global_config)
+        self.rotary_emb = BLTRotaryEmbedding(config=config.global_layer_config)
 
         self.token_embedding_projection = None
         if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global:
@@ -1292,17 +1263,8 @@ class BLTPatcher(BLTPreTrainedModel):
         self.rotary_emb = BLTRotaryEmbedding(config=self.config)
 
         self.layers = nn.ModuleList()
-        # Set up config for layers with proper dimensions
-        patcher_config = self.config
-        patcher_config.hidden_size = self.config.dim
-        patcher_config.num_attention_heads = self.config.n_heads
-        patcher_config.num_key_value_heads = getattr(self.config, 'n_kv_heads', None) or self.config.n_heads
-        patcher_config.head_dim = self.config.dim // self.config.n_heads
-        patcher_config.intermediate_size = self.config.multiple_of * ((int(8 * self.config.dim / 3) + self.config.multiple_of - 1) // self.config.multiple_of)
-
         for layer_idx in range(self.config.n_layers):
-            self.layers.append(BLTTransformerLayer(patcher_config, layer_idx))
+            self.layers.append(BLTTransformerLayer(config.patcher_layer_config, layer_idx))
 
         self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim)