some config renaming

Authored by ita.zaporozhets@huggingface.co on 2025-06-26 12:26:42 +00:00; committed by ita.zaporozhets@huggingface.co
parent 5eb9a9426a
commit ef1147e128
2 changed files with 57 additions and 40 deletions


@@ -258,7 +258,7 @@ class BLTPatcherConfig(PretrainedConfig):
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.dim = dim
+        self.hidden_size = dim
         self.n_layers = n_layers
         self.n_heads = n_heads
         self.head_dim = head_dim if head_dim is not None else (dim // n_heads)
@@ -288,7 +288,7 @@ class BLTPatcherConfig(PretrainedConfig):
         self.hidden_act = "silu"  # BLT uses silu activation
         # Calculate intermediate_size using BLTMLP logic based on actual hidden_size
-        self.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)
+        self.intermediate_size = multiple_of * ((int(8 * self.hidden_size / 3) + multiple_of - 1) // multiple_of)

         # Set simple rope scaling for patcher (no complex dynamic rope)
         self.rope_scaling = {"rope_type": "default"}
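
The intermediate_size expression in the hunk above is the usual SwiGLU sizing: take 8/3 of the hidden size and round up to the nearest multiple of multiple_of. A minimal sketch of the arithmetic, assuming hidden_size=512 and multiple_of=256 (illustrative values only; the real defaults are not shown in this diff):

# Illustrative values, not the model's actual defaults
hidden_size = 512
multiple_of = 256
intermediate_size = multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of)
print(intermediate_size)  # int(4096 / 3) = 1365, rounded up to a multiple of 256 -> 1536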
@@ -305,20 +305,20 @@ class BLTConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 256):
             Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
-        max_seqlen (`int`, *optional*, defaults to 1024):
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
             The maximum sequence length that this model can handle.

         # Main architecture dimensions
-        dim (`int`, *optional*, defaults to 512):
+        hidden_size (`int`, *optional*, defaults to 512):
             Main dimension of the model.
-        n_layers (`int`, *optional*, defaults to 8):
+        num_hidden_layers (`int`, *optional*, defaults to 8):
             Number of layers in the main transformer.
-        n_heads (`int`, *optional*, defaults to 8):
+        num_attention_heads (`int`, *optional*, defaults to 8):
             Number of attention heads in the main transformer.
         head_dim (`int`, *optional*):
-            Dimension of each attention head. If not specified, computed as dim // n_heads.
-        n_kv_heads (`int`, *optional*):
-            Number of key-value heads for grouped query attention. If not specified, defaults to n_heads.
+            Dimension of each attention head. If not specified, computed as hidden_size // num_attention_heads.
+        num_key_value_heads (`int`, *optional*):
+            Number of key-value heads for grouped query attention. If not specified, defaults to num_attention_heads.

         # Component-specific dimensions
         dim_global (`int`, *optional*, defaults to 512):
@@ -464,13 +464,13 @@ class BLTConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=256,
-        max_seqlen=1024,
+        max_position_embeddings=1024,
         # Main architecture dimensions
-        dim=512,
-        n_layers=8,
-        n_heads=8,
+        hidden_size=512,
+        num_hidden_layers=8,
+        num_attention_heads=8,
         head_dim=None,
-        n_kv_heads=None,
+        num_key_value_heads=None,
         # Component-specific dimensions
         dim_global=512,
         dim_local_decoder=512,
@@ -542,14 +542,14 @@ class BLTConfig(PretrainedConfig):
         # Basic model configuration
         self.vocab_size = vocab_size
-        self.max_seqlen = max_seqlen
+        self.max_position_embeddings = max_position_embeddings

         # Main architecture dimensions
-        self.dim = dim
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.head_dim = head_dim if head_dim is not None else (dim // n_heads)
-        self.n_kv_heads = n_kv_heads
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim if head_dim is not None else (hidden_size // num_attention_heads)
+        self.num_key_value_heads = num_key_value_heads

         # Component-specific dimensions
         self.dim_global = dim_global
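
Taken together, the hunks above rename the top-level BLTConfig arguments to the standard transformers names (max_seqlen -> max_position_embeddings, dim -> hidden_size, n_layers -> num_hidden_layers, n_heads -> num_attention_heads, n_kv_heads -> num_key_value_heads). A minimal sketch of constructing the config with the new names, assuming BLTConfig is importable from transformers once this model lands; the defaults shown are the ones visible in this diff:

from transformers import BLTConfig  # assumed import path; BLT may not be in a released build yet

config = BLTConfig(
    vocab_size=256,
    max_position_embeddings=1024,  # was max_seqlen
    hidden_size=512,               # was dim
    num_hidden_layers=8,           # was n_layers
    num_attention_heads=8,         # was n_heads
    num_key_value_heads=None,      # was n_kv_heads; per the docstring, defaults to num_attention_heads
    head_dim=None,                 # computed as hidden_size // num_attention_heads = 64
)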
@@ -630,11 +630,11 @@ class BLTConfig(PretrainedConfig):
             pm_size=pm_size,
             hidden_size=dim_local_encoder,
             num_attention_heads=n_heads_local_encoder,
-            num_key_value_heads=n_kv_heads,
+            num_key_value_heads=num_key_value_heads,
             num_hidden_layers=n_layers_local_encoder,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_encoder_seq_length or max_seqlen,
+            max_position_embeddings=max_encoder_seq_length or max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -648,11 +648,11 @@ class BLTConfig(PretrainedConfig):
             dim_global=dim_global,
             hidden_size=dim_local_decoder,
             num_attention_heads=n_heads_local_decoder,
-            num_key_value_heads=n_kv_heads,
+            num_key_value_heads=num_key_value_heads,
             num_hidden_layers=n_layers_local_decoder,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_encoder_seq_length or max_seqlen,
+            max_position_embeddings=max_encoder_seq_length or max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -666,7 +666,7 @@ class BLTConfig(PretrainedConfig):
             num_hidden_layers=n_layers_global,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_seqlen,
+            max_position_embeddings=max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -690,7 +690,7 @@ class BLTConfig(PretrainedConfig):

         # Set compatibility attributes for transformers
         self.num_key_value_heads = n_heads_local_encoder
-        self.max_position_embeddings = max_seqlen
+        self.max_position_embeddings = max_position_embeddings
         self.hidden_size = dim_local_encoder
         self.num_attention_heads = n_heads_local_encoder
@@ -705,6 +705,8 @@ class BLTConfig(PretrainedConfig):
             **kwargs,
         )

 __all__ = [
     "BLTConfig",
     "BLTPatcherConfig",
@@ -714,3 +716,4 @@ __all__ = [
     "InitStdFactor",
     "PatchingModeEnum"
 ]


@@ -228,33 +228,47 @@ def _prepare_patch_cross_attention_mask(
     return cross_attention_mask, full_text_row_masked_out_mask


-def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+    """
+    Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+    Pads the result to uniform length across the batch.
+
+    Args:
+        patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+        max_patch_length (int, optional): Maximum allowed length per patch.
+
+    Returns:
+        torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+    """
     if max_patch_length is None:
         return patch_lengths

     batch_size = patch_lengths.size(0)
-    split_all = []
-    max_len = 0
+    processed = []

     for seq in patch_lengths:
         splits = []
         for length in seq[seq > 0]:
-            # Split long patches into max_patch_length chunks
-            full, rem = divmod(length.item(), max_patch_length)
-            splits.extend([max_patch_length] * full + ([rem] if rem else []))
-        split_all.append(splits)
-        max_len = max(max_len, len(splits))
+            length = length.item()
+            full_chunks, remainder = divmod(length, max_patch_length)
+            splits.extend([max_patch_length] * full_chunks)
+            if remainder:
+                splits.append(remainder)
+        processed.append(splits)

-    # Pad sequences to the maximum length
+    # Find max length to pad to
+    max_len = max(len(splits) for splits in processed)

     padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
-    for i, splits in enumerate(split_all):
+    for i, splits in enumerate(processed):
         if splits:
             padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)

-    # Trim trailing columns that are all zeros
-    last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
-    if last_non_zero < padded.shape[1]:
-        padded = padded[:, :padded.shape[1] - last_non_zero]
+    # Trim zero columns
+    if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+        last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+        padded = padded[:, :last_nonzero]

     return padded
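
A quick usage sketch of the rewritten process_patch_lengths, with illustrative inputs; the expected output follows directly from the splitting and padding logic in the hunk above:

import torch

patch_lengths = torch.tensor([[5, 3, 0],
                              [2, 9, 1]])

# With max_patch_length=4, the 5 splits into [4, 1] and the 9 into [4, 4, 1];
# rows are then zero-padded to the longest resulting sequence.
out = process_patch_lengths(patch_lengths, max_patch_length=4)
# tensor([[4, 1, 3, 0, 0],
#         [2, 4, 4, 1, 1]])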