some config renaming

Authored by ita.zaporozhets@huggingface.co on 2025-06-26 12:26:42 +00:00; committed by ita.zaporozhets@huggingface.co
parent 5eb9a9426a
commit ef1147e128
2 changed files with 57 additions and 40 deletions

Changed file 1 of 2 (BLT configuration classes):

@@ -258,7 +258,7 @@ class BLTPatcherConfig(PretrainedConfig):
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.dim = dim
+        self.hidden_size = dim
         self.n_layers = n_layers
         self.n_heads = n_heads
         self.head_dim = head_dim if head_dim is not None else (dim // n_heads)
@@ -288,7 +288,7 @@ class BLTPatcherConfig(PretrainedConfig):
         self.hidden_act = "silu"  # BLT uses silu activation
         # Calculate intermediate_size using BLTMLP logic based on actual hidden_size
-        self.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)
+        self.intermediate_size = multiple_of * ((int(8 * self.hidden_size / 3) + multiple_of - 1) // multiple_of)

         # Set simple rope scaling for patcher (no complex dynamic rope)
         self.rope_scaling = {"rope_type": "default"}
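
For reference, a quick sketch of the rounding this formula performs. The values hidden_size=512 and multiple_of=256 are purely illustrative (not necessarily the patcher's defaults):

    # Illustrative only: SwiGLU-style sizing, rounded up to a multiple of `multiple_of`.
    hidden_size, multiple_of = 512, 256
    intermediate_size = multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of)
    print(intermediate_size)  # 1536: int(8 * 512 / 3) = 1365, rounded up to the next multiple of 256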
@@ -305,20 +305,20 @@ class BLTConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 256):
             Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
-        max_seqlen (`int`, *optional*, defaults to 1024):
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
             The maximum sequence length that this model can handle.

         # Main architecture dimensions
-        dim (`int`, *optional*, defaults to 512):
+        hidden_size (`int`, *optional*, defaults to 512):
             Main dimension of the model.
-        n_layers (`int`, *optional*, defaults to 8):
+        num_hidden_layers (`int`, *optional*, defaults to 8):
             Number of layers in the main transformer.
-        n_heads (`int`, *optional*, defaults to 8):
+        num_attention_heads (`int`, *optional*, defaults to 8):
             Number of attention heads in the main transformer.
         head_dim (`int`, *optional*):
-            Dimension of each attention head. If not specified, computed as dim // n_heads.
-        n_kv_heads (`int`, *optional*):
-            Number of key-value heads for grouped query attention. If not specified, defaults to n_heads.
+            Dimension of each attention head. If not specified, computed as hidden_size // num_attention_heads.
+        num_key_value_heads (`int`, *optional*):
+            Number of key-value heads for grouped query attention. If not specified, defaults to num_attention_heads.

         # Component-specific dimensions
         dim_global (`int`, *optional*, defaults to 512):
@@ -464,13 +464,13 @@ class BLTConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=256,
-        max_seqlen=1024,
+        max_position_embeddings=1024,
         # Main architecture dimensions
-        dim=512,
-        n_layers=8,
-        n_heads=8,
+        hidden_size=512,
+        num_hidden_layers=8,
+        num_attention_heads=8,
         head_dim=None,
-        n_kv_heads=None,
+        num_key_value_heads=None,
         # Component-specific dimensions
         dim_global=512,
         dim_local_decoder=512,
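
For clarity, a minimal usage sketch of the renamed keywords, using the documented defaults. This assumes the BLTConfig class from this file is importable; the old dim / n_layers / n_heads / n_kv_heads / max_seqlen names are gone:

    # Sketch only: the former dim/n_layers/n_heads/n_kv_heads/max_seqlen keywords
    # are now passed under the standard transformers names.
    config = BLTConfig(
        vocab_size=256,
        max_position_embeddings=1024,
        hidden_size=512,
        num_hidden_layers=8,
        num_attention_heads=8,
        num_key_value_heads=None,  # per the docstring, falls back to num_attention_heads
        head_dim=None,             # computed as hidden_size // num_attention_heads when unset
    )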
@@ -542,14 +542,14 @@ class BLTConfig(PretrainedConfig):

         # Basic model configuration
         self.vocab_size = vocab_size
-        self.max_seqlen = max_seqlen
+        self.max_position_embeddings = max_position_embeddings

         # Main architecture dimensions
-        self.dim = dim
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.head_dim = head_dim if head_dim is not None else (dim // n_heads)
-        self.n_kv_heads = n_kv_heads
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim if head_dim is not None else (hidden_size // num_attention_heads)
+        self.num_key_value_heads = num_key_value_heads

         # Component-specific dimensions
         self.dim_global = dim_global
@@ -630,11 +630,11 @@ class BLTConfig(PretrainedConfig):
             pm_size=pm_size,
             hidden_size=dim_local_encoder,
             num_attention_heads=n_heads_local_encoder,
-            num_key_value_heads=n_kv_heads,
+            num_key_value_heads=num_key_value_heads,
             num_hidden_layers=n_layers_local_encoder,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_encoder_seq_length or max_seqlen,
+            max_position_embeddings=max_encoder_seq_length or max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -648,11 +648,11 @@ class BLTConfig(PretrainedConfig):
             dim_global=dim_global,
             hidden_size=dim_local_decoder,
             num_attention_heads=n_heads_local_decoder,
-            num_key_value_heads=n_kv_heads,
+            num_key_value_heads=num_key_value_heads,
             num_hidden_layers=n_layers_local_decoder,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_encoder_seq_length or max_seqlen,
+            max_position_embeddings=max_encoder_seq_length or max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -666,7 +666,7 @@ class BLTConfig(PretrainedConfig):
             num_hidden_layers=n_layers_global,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_seqlen,
+            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -690,7 +690,7 @@ class BLTConfig(PretrainedConfig):

         # Set compatibility attributes for transformers
         self.num_key_value_heads = n_heads_local_encoder
-        self.max_position_embeddings = max_seqlen
+        self.max_position_embeddings = max_position_embeddings
         self.hidden_size = dim_local_encoder
         self.num_attention_heads = n_heads_local_encoder
@@ -705,6 +705,8 @@ class BLTConfig(PretrainedConfig):
             **kwargs,
         )


 __all__ = [
     "BLTConfig",
     "BLTPatcherConfig",
@@ -714,3 +716,4 @@ __all__ = [
     "InitStdFactor",
     "PatchingModeEnum"
 ]

Changed file 2 of 2 (BLT modeling helpers):

@@ -228,33 +228,47 @@ def _prepare_patch_cross_attention_mask(
     return cross_attention_mask, full_text_row_masked_out_mask


-def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+    """
+    Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+    Pads the result to uniform length across the batch.
+
+    Args:
+        patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+        max_patch_length (int, optional): Maximum allowed length per patch.
+
+    Returns:
+        torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+    """
     if max_patch_length is None:
         return patch_lengths

     batch_size = patch_lengths.size(0)
-    split_all = []
-    max_len = 0
+    processed = []

     for seq in patch_lengths:
         splits = []
         for length in seq[seq > 0]:
-            # Split long patches into max_patch_length chunks
-            full, rem = divmod(length.item(), max_patch_length)
-            splits.extend([max_patch_length] * full + ([rem] if rem else []))
-        split_all.append(splits)
-        max_len = max(max_len, len(splits))
+            length = length.item()
+            full_chunks, remainder = divmod(length, max_patch_length)
+            splits.extend([max_patch_length] * full_chunks)
+            if remainder:
+                splits.append(remainder)
+        processed.append(splits)

-    # Pad sequences to the maximum length
+    # Find max length to pad to
+    max_len = max(len(splits) for splits in processed)
     padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
-    for i, splits in enumerate(split_all):
+
+    for i, splits in enumerate(processed):
         if splits:
             padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)

-    # Trim trailing columns that are all zeros
-    last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
-    if last_non_zero < padded.shape[1]:
-        padded = padded[:, :padded.shape[1] - last_non_zero]
+    # Trim zero columns
+    if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+        last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+        padded = padded[:, :last_nonzero]

     return padded
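
To illustrate the reworked helper's behavior, a small sketch with made-up values (assumes torch and the process_patch_lengths defined above):

    import torch

    # Hypothetical patch lengths for two sequences; 0 marks padding in the input.
    patch_lengths = torch.tensor([[5, 3, 0],
                                  [2, 7, 1]])
    out = process_patch_lengths(patch_lengths, max_patch_length=4)
    # Patches longer than 4 are split into chunks of 4 plus a remainder, then right-padded:
    # tensor([[4, 1, 3, 0],
    #         [2, 4, 3, 1]])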