Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
some config renaming

commit ef1147e128 (parent 5eb9a9426a)
@@ -258,7 +258,7 @@ class BLTPatcherConfig(PretrainedConfig):
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.dim = dim
+        self.hidden_size = dim
         self.n_layers = n_layers
         self.n_heads = n_heads
         self.head_dim = head_dim if head_dim is not None else (dim // n_heads)
@@ -288,7 +288,7 @@ class BLTPatcherConfig(PretrainedConfig):
         self.hidden_act = "silu" # BLT uses silu activation
 
         # Calculate intermediate_size using BLTMLP logic based on actual hidden_size
-        self.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of)
+        self.intermediate_size = multiple_of * ((int(8 * self.hidden_size / 3) + multiple_of - 1) // multiple_of)
 
         # Set simple rope scaling for patcher (no complex dynamic rope)
         self.rope_scaling = {"rope_type": "default"}
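For reference, the renamed formula rounds 8 * hidden_size / 3 up to a multiple of multiple_of; a quick arithmetic check as a minimal sketch (hidden_size=512 and multiple_of=256 are illustrative values, not defaults confirmed by this diff):

hidden_size = 512
multiple_of = 256  # illustrative; the patcher's actual default is not shown in this diff
intermediate_size = multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of)
# int(8 * 512 / 3) = 1365; (1365 + 255) // 256 = 6; 6 * 256 = 1536
print(intermediate_size)  # 1536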
@@ -305,20 +305,20 @@ class BLTConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 256):
             Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
-        max_seqlen (`int`, *optional*, defaults to 1024):
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
             The maximum sequence length that this model can handle.
 
         # Main architecture dimensions
-        dim (`int`, *optional*, defaults to 512):
+        hidden_size (`int`, *optional*, defaults to 512):
             Main dimension of the model.
-        n_layers (`int`, *optional*, defaults to 8):
+        num_hidden_layers (`int`, *optional*, defaults to 8):
             Number of layers in the main transformer.
-        n_heads (`int`, *optional*, defaults to 8):
+        num_attention_heads (`int`, *optional*, defaults to 8):
             Number of attention heads in the main transformer.
         head_dim (`int`, *optional*):
-            Dimension of each attention head. If not specified, computed as dim // n_heads.
-        n_kv_heads (`int`, *optional*):
-            Number of key-value heads for grouped query attention. If not specified, defaults to n_heads.
+            Dimension of each attention head. If not specified, computed as hidden_size // num_attention_heads.
+        num_key_value_heads (`int`, *optional*):
+            Number of key-value heads for grouped query attention. If not specified, defaults to num_attention_heads.
 
         # Component-specific dimensions
         dim_global (`int`, *optional*, defaults to 512):
@@ -464,13 +464,13 @@ class BLTConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=256,
-        max_seqlen=1024,
+        max_position_embeddings=1024,
         # Main architecture dimensions
-        dim=512,
-        n_layers=8,
-        n_heads=8,
+        hidden_size=512,
+        num_hidden_layers=8,
+        num_attention_heads=8,
         head_dim=None,
-        n_kv_heads=None,
+        num_key_value_heads=None,
         # Component-specific dimensions
         dim_global=512,
         dim_local_decoder=512,
@@ -542,14 +542,14 @@ class BLTConfig(PretrainedConfig):
 
         # Basic model configuration
         self.vocab_size = vocab_size
-        self.max_seqlen = max_seqlen
+        self.max_position_embeddings = max_position_embeddings
 
         # Main architecture dimensions
-        self.dim = dim
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.head_dim = head_dim if head_dim is not None else (dim // n_heads)
-        self.n_kv_heads = n_kv_heads
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim if head_dim is not None else (hidden_size // num_attention_heads)
+        self.num_key_value_heads = num_key_value_heads
 
         # Component-specific dimensions
         self.dim_global = dim_global
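With the renaming in the three hunks above, the config is constructed with the standard transformers argument names; a minimal sketch (the import path is an assumption for a model that is still being ported, and the values are the documented defaults):

from transformers import BLTConfig  # assumed import path once the BLT port is merged

config = BLTConfig(
    vocab_size=256,
    max_position_embeddings=1024,  # was max_seqlen
    hidden_size=512,               # was dim
    num_hidden_layers=8,           # was n_layers
    num_attention_heads=8,         # was n_heads
    num_key_value_heads=None,      # was n_kv_heads; per the docstring, defaults to num_attention_heads
)
assert config.head_dim == config.hidden_size // config.num_attention_heads  # 512 // 8 == 64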
@@ -630,11 +630,11 @@ class BLTConfig(PretrainedConfig):
             pm_size=pm_size,
             hidden_size=dim_local_encoder,
             num_attention_heads=n_heads_local_encoder,
-            num_key_value_heads=n_kv_heads,
+            num_key_value_heads=num_key_value_heads,
             num_hidden_layers=n_layers_local_encoder,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_encoder_seq_length or max_seqlen,
+            max_position_embeddings=max_encoder_seq_length or max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
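A behavioural note on the renamed fallback above: `max_encoder_seq_length or max_position_embeddings` uses Python's `or`, so any falsy value (None, and also 0) for max_encoder_seq_length makes the encoder sub-config inherit the model-wide max_position_embeddings. A small sketch with assumed values:

max_position_embeddings = 1024
for max_encoder_seq_length in (None, 0, 2048):
    effective = max_encoder_seq_length or max_position_embeddings
    print(max_encoder_seq_length, "->", effective)
# None -> 1024
# 0 -> 1024
# 2048 -> 2048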
@@ -648,11 +648,11 @@ class BLTConfig(PretrainedConfig):
             dim_global=dim_global,
             hidden_size=dim_local_decoder,
             num_attention_heads=n_heads_local_decoder,
-            num_key_value_heads=n_kv_heads,
+            num_key_value_heads=num_key_value_heads,
             num_hidden_layers=n_layers_local_decoder,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_encoder_seq_length or max_seqlen,
+            max_position_embeddings=max_encoder_seq_length or max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -666,7 +666,7 @@ class BLTConfig(PretrainedConfig):
             num_hidden_layers=n_layers_global,
             norm_eps=norm_eps,
             dropout=dropout,
-            max_position_embeddings=max_seqlen,
+            max_position_embeddings=max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling={"rope_type": "default"},
             hidden_act=hidden_act,
@@ -690,7 +690,7 @@ class BLTConfig(PretrainedConfig):
 
         # Set compatibility attributes for transformers
         self.num_key_value_heads = n_heads_local_encoder
-        self.max_position_embeddings = max_seqlen
+        self.max_position_embeddings = max_position_embeddings
         self.hidden_size = dim_local_encoder
         self.num_attention_heads = n_heads_local_encoder
 
@@ -705,6 +705,8 @@
             **kwargs,
         )
 
+
+
 __all__ = [
     "BLTConfig",
     "BLTPatcherConfig",
@@ -714,3 +716,4 @@ __all__ = [
     "InitStdFactor",
     "PatchingModeEnum"
 ]
+
@@ -228,33 +228,47 @@ def _prepare_patch_cross_attention_mask(
 
     return cross_attention_mask, full_text_row_masked_out_mask
 
-def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+    """
+    Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+    Pads the result to uniform length across the batch.
+
+    Args:
+        patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+        max_patch_length (int, optional): Maximum allowed length per patch.
+
+    Returns:
+        torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+    """
     if max_patch_length is None:
         return patch_lengths
 
     batch_size = patch_lengths.size(0)
-    split_all = []
-    max_len = 0
+    processed = []
 
     for seq in patch_lengths:
        splits = []
        for length in seq[seq > 0]:
            # Split long patches into max_patch_length chunks
-            full, rem = divmod(length.item(), max_patch_length)
-            splits.extend([max_patch_length] * full + ([rem] if rem else []))
-        split_all.append(splits)
-        max_len = max(max_len, len(splits))
+            length = length.item()
+            full_chunks, remainder = divmod(length, max_patch_length)
+            splits.extend([max_patch_length] * full_chunks)
+            if remainder:
+                splits.append(remainder)
+        processed.append(splits)
 
-    # Pad sequences to the maximum length
+    # Find max length to pad to
+    max_len = max(len(splits) for splits in processed)
     padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
-    for i, splits in enumerate(split_all):
+
+    for i, splits in enumerate(processed):
         if splits:
             padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
 
-    # Trim trailing columns that are all zeros
-    last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
-    if last_non_zero < padded.shape[1]:
-        padded = padded[:, :padded.shape[1] - last_non_zero]
+    # Trim zero columns
+    if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+        last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+        padded = padded[:, :last_nonzero]
 
     return padded
 
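A worked example of the rewritten process_patch_lengths (a sketch; the tensor values are invented for illustration, and the function from the hunk above is assumed to be in scope):

import torch

patch_lengths = torch.tensor([[5, 12, 3],
                              [7,  0, 0]])  # zeros are padding

out = process_patch_lengths(patch_lengths, max_patch_length=4)
# Row 0: 5 -> [4, 1], 12 -> [4, 4, 4], 3 -> [3]  => [4, 1, 4, 4, 4, 3]
# Row 1: 7 -> [4, 3]                             => [4, 3, 0, 0, 0, 0]
# out == tensor([[4, 1, 4, 4, 4, 3],
#                [4, 3, 0, 0, 0, 0]])
# With max_patch_length=None the input tensor is returned unchanged.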