clean up patcher helpers further

Author: ita.zaporozhets@huggingface.co
Date: 2025-06-25 14:55:50 +00:00
Committed by: ita.zaporozhets@huggingface.co
Parent: 5c218735fb
Commit: 09ad5a51ea


@ -40,14 +40,11 @@ from .configuration_blt import (
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@ -59,7 +56,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
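For reference, a quick shape check (illustrative, not part of this commit) of what repeat_kv does with grouped key/value heads; per the docstring above it matches torch.repeat_interleave on dim=1:
import torch
# 2 KV heads repeated 3x -> 6 heads, same layout as repeat_interleave.
kv = torch.randn(1, 2, 5, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=3)
assert expanded.shape == (1, 6, 5, 8)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))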
# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP
class BLTMLP(nn.Module):
def __init__(self, config):
super().__init__()
@ -75,7 +71,6 @@ class BLTMLP(nn.Module):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
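For context (not part of the diff): the forward above is the usual gated-MLP pattern. A self-contained sketch with assumed sizes and SiLU as the activation (the real module takes both from the config):
import torch
from torch import nn
hidden_size, intermediate_size = 16, 64  # assumed values
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act_fn = nn.SiLU()  # assumption; the model reads the activation from its config
x = torch.randn(2, 3, hidden_size)
out = down_proj(act_fn(gate_proj(x)) * up_proj(x))  # mirrors BLTMLP.forward
assert out.shape == x.shape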
# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@ -135,7 +130,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_rot.type_as(q), k_rot.type_as(k)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class BLTRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@ -235,7 +229,6 @@ class BLTTransformerLayer(nn.Module):
return outputs
# Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
class BLTSelfAttention(nn.Module):
def __init__(self, config, layer_idx: int):
super().__init__()
@ -316,59 +309,66 @@ class BLTSelfAttention(nn.Module):
attn_weights = None
return attn_output, attn_weights, past_key_value
def check_non_zero_after_zero(tensor):
zero_mask = tensor == 0
shifted_mask = torch.cat(
[
torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
zero_mask[:, :-1],
],
dim=1,
)
non_zero_after_zero = (tensor != 0) & shifted_mask
return non_zero_after_zero.any()
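A small illustration (not from the commit) of what check_non_zero_after_zero flags: a non-zero value appearing after a zero in the same row, e.g. malformed trailing padding:
import torch
ok = torch.tensor([[3, 5, 7, 0, 0]])   # zeros only at the end
bad = torch.tensor([[3, 0, 7, 0, 0]])  # non-zero after a zero
assert not check_non_zero_after_zero(ok)
assert check_non_zero_after_zero(bad)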
def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0):
primes = [
-     1000000007,
-     5915587277,
-     1500450271,
-     3267000013,
-     5754853343,
-     4093082899,
-     9576890767,
-     3628273133,
-     2860486313,
-     5463458053,
-     3367900313,
+     1000000007, 5915587277, 1500450271, 3267000013, 5754853343,
+     4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313,
]
prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device)
- prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])])
+ powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
+ prime_powers = prime ** powers
return torch.sum(token_tensor * prime_powers, dim=-1)
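For intuition (illustrative only): each window is reduced with increasing powers of the selected prime, so the removed torch.stack formulation and the new prime ** arange formulation produce identical hashes. Assuming hash_func_nb=0 picks the first prime above:
import torch
windows = torch.tensor([[[1, 2, 3], [4, 5, 6]]], dtype=torch.int64)  # (batch, num_windows, group_size)
hashes = rolling_polynomial_hash(windows, hash_func_nb=0)
p = 1000000007
expected = torch.tensor([[1 + 2 * p + 3 * p**2, 4 + 5 * p + 6 * p**2]])
assert torch.equal(hashes, expected)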
def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
"""
Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash].
expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab.
returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash].
Note: max hash can make a big difference on the number of collisions.
"""
"""Hash token groups and map to range [0, max_hash]."""
with torch.no_grad():
batch_size, seq_len = token_ids.shape
- prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
- token_ids = torch.cat([prefix, token_ids], dim=1)
- windows = token_ids.unfold(1, group_size, 1)
- # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
+ # Add padding for sliding window
+ padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
+ padded_tokens = torch.cat([padding, token_ids], dim=1)
+ # Create sliding windows and compute hashes
+ windows = padded_tokens.unfold(1, group_size, 1)
hashes = rolling_polynomial_hash(windows, hash_func_nb)
- hash_values_range = hashes % max_hash
- hash_values_range.requires_grad = False
- return hash_values_range
+ hash_values = hashes % max_hash
+ hash_values.requires_grad = False
+ return hash_values
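A shape-level sketch (not part of the commit) of byte_group_hash_function; the byte values and sizes here are arbitrary:
import torch
token_ids = torch.randint(0, 256, (2, 10))  # (batch_size, seq_len) of byte ids
hash_ids = byte_group_hash_function(token_ids, group_size=2, hash_func_nb=0, max_hash=30000)
assert hash_ids.shape == token_ids.shape  # one bucket id per position
assert 0 <= int(hash_ids.min()) and int(hash_ids.max()) < 30000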
+ def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list):
+ """Initialize hash-based token embeddings for the BLT encoder."""
+ num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
+ embeddings = [
+ nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
+ for _ in range(num_embeddings)
+ ]
+ return nn.ModuleList(embeddings)
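Illustrative note on sizing (config values below are assumptions, not from the commit): init_hash_embeddings builds one table per (hash function, group size) pair, which is what the sequential indexing in compute_hash_embeddings relies on:
from types import SimpleNamespace
cfg = SimpleNamespace(encoder_hash_byte_group_nb_functions=2, encoder_hash_byte_group_vocab=30000)
tables = init_hash_embeddings(cfg, local_encoder_dim=256, encoder_hash_byte_group_size=[3, 4, 5])
assert len(tables) == 2 * 3  # nb_functions * number of group sizes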
+ def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.ModuleList,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+ ) -> torch.Tensor:
+ """Compute token embeddings enhanced with hash-based embeddings."""
+ embeddings = local_encoder.embed_tokens(local_encoder_tokens)
+ embedding_idx = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ for group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(
+ local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab
+ )
+ embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids)
+ embedding_idx += 1
+ return embeddings
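And a minimal end-to-end sketch (illustrative; local_encoder is a stand-in exposing only embed_tokens, and all sizes are assumptions) showing the hash embeddings being added on top of the plain byte embeddings:
import torch
from torch import nn
from types import SimpleNamespace
dim, group_sizes, n_funcs, vocab = 256, [3, 4], 2, 30000  # assumed values
local_encoder = SimpleNamespace(embed_tokens=nn.Embedding(260, dim))
cfg = SimpleNamespace(encoder_hash_byte_group_nb_functions=n_funcs, encoder_hash_byte_group_vocab=vocab)
hash_tables = init_hash_embeddings(cfg, dim, group_sizes)
tokens = torch.randint(0, 260, (2, 16))
embeds = compute_hash_embeddings(tokens, local_encoder, hash_tables, n_funcs, group_sizes, vocab)
assert embeds.shape == (2, 16, dim)  # one hash-table lookup added per (function, group size) pair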
def _prepare_patch_cross_attention_mask(
@ -455,34 +455,47 @@ def _prepare_patch_cross_attention_mask(
return cross_attention_mask, full_text_row_masked_out_mask
- def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
- #TODO: refactor to be more readable
+ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+ """
+ Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+ Pads the result to uniform length across the batch.
+ Args:
+ patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+ max_patch_length (int, optional): Maximum allowed length per patch.
+ Returns:
+ torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+ """
if max_patch_length is None:
return patch_lengths
batch_size = patch_lengths.size(0)
- split_all = []
- max_len = 0
+ processed = []
for seq in patch_lengths:
splits = []
for length in seq[seq > 0]:
# Split long patches into max_patch_length chunks
- full, rem = divmod(length.item(), max_patch_length)
- splits.extend([max_patch_length] * full + ([rem] if rem else []))
- split_all.append(splits)
- max_len = max(max_len, len(splits))
+ length = length.item()
+ full_chunks, remainder = divmod(length, max_patch_length)
+ splits.extend([max_patch_length] * full_chunks)
+ if remainder:
+ splits.append(remainder)
+ processed.append(splits)
- # Pad sequences to the maximum length
+ # Find max length to pad to
+ max_len = max(len(splits) for splits in processed)
padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
- for i, splits in enumerate(split_all):
+ for i, splits in enumerate(processed):
if splits:
padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
- # Trim trailing columns that are all zeros
- last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
- if last_non_zero < padded.shape[1]:
- padded = padded[:, :padded.shape[1] - last_non_zero]
+ # Trim zero columns
+ if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+ last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+ padded = padded[:, :last_nonzero]
return padded
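A worked example (not part of the commit) of the split-and-pad behaviour: 5 becomes [4, 1], 12 becomes [4, 4, 4], and shorter rows are right-padded with zeros:
import torch
patch_lengths = torch.tensor([[5, 12], [7, 0]])
out = process_patch_lengths(patch_lengths, max_patch_length=4)
assert out.tolist() == [[4, 1, 4, 4, 4], [4, 3, 0, 0, 0]]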
@ -627,12 +640,11 @@ class BLTLocalDecoder(nn.Module):
def __init__(self, config: BLTLocalDecoderConfig):
super().__init__()
# Extract config values to instance attributes
self.hidden_size = config.hidden_size
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.dropout = config.dropout
- self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove
+ self.cross_attn_decoder = True
self.cross_attn_all_layers = config.cross_attn_all_layers
self.cross_attn_k = config.cross_attn_k
@ -655,11 +667,7 @@ class BLTLocalDecoder(nn.Module):
BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size)
)
- self.lm_head = nn.Linear(
- self.hidden_size,
- self.vocab_size,
- bias=False,
- )
+ self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
def forward(
@ -851,48 +859,6 @@ class BLTGlobalTransformer(nn.Module):
return hidden_states, cache
- def compute_hash_embeddings(
- local_encoder_tokens: torch.Tensor,
- local_encoder,
- encoder_hash_tok_embedding: nn.ModuleList,
- encoder_hash_byte_group_nb_functions: int,
- encoder_hash_byte_group_size: list,
- encoder_hash_byte_group_vocab: int,
- ) -> torch.Tensor:
- """
- Compute embeddings using hash token embeddings.
- Args:
- local_encoder_tokens: Input tokens tensor
- local_encoder: Encoder object with embed_tokens method
- encoder_hash_tok_embedding: ModuleList of hash token embeddings
- encoder_hash_byte_group_nb_functions: Number of hash functions
- encoder_hash_byte_group_size: List of byte group sizes
- encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
- Returns:
- torch.Tensor: Combined embeddings
- """
- if encoder_hash_tok_embedding is None:
- return None
- local_encoder_embeds = local_encoder.embed_tokens(local_encoder_tokens)
- i = 0
- for func_nb in range(encoder_hash_byte_group_nb_functions):
- for byte_group_size in encoder_hash_byte_group_size:
- hash_ids = byte_group_hash_function(
- local_encoder_tokens,
- byte_group_size,
- hash_func_nb=func_nb,
- max_hash=encoder_hash_byte_group_vocab,
- )
- hash_tok_embedding = encoder_hash_tok_embedding[i]
- local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
- i += 1
- assert i == len(encoder_hash_tok_embedding)
- return local_encoder_embeds
class BLTPreTrainedModel(PreTrainedModel):
@ -982,7 +948,7 @@ class BLTModel(BLTPreTrainedModel):
# Patcher initialization
if self.patch_in_forward:
- self.patcher = BLTPatcher(config)
+ self.patcher = BLTPatcher(config.patcher_config)
self.patcher.eval()
for param in self.patcher.parameters():
param.requires_grad = False
@ -1076,8 +1042,8 @@ class BLTModel(BLTPreTrainedModel):
class BLTPatcher(BLTPreTrainedModel):
- def __init__(self, config):
- super().__init__(config.patcher_config)
+ def __init__(self, config: BLTPatcherConfig):
+ super().__init__(config)
self.rotary_emb = BLTRotaryEmbedding(config=self.config)
@ -1223,34 +1189,6 @@ class BLTPatcher(BLTPreTrainedModel):
return patch_lengths
- def init_hash_embeddings(
- config,
- local_encoder_dim: int,
- encoder_hash_byte_group_size: list,
- ):
- """Initialize hash-based token embeddings for the BLT encoder."""
- if config.encoder_hash_byte_group_size is None:
- return None
- embeddings = []
- emb_dim = local_encoder_dim
- encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
- for _ in range(config.encoder_hash_byte_group_nb_functions):
- for _ in encoder_hash_byte_group_size:
- embeddings.append(
- nn.Embedding(
- encoder_hash_byte_group_vocab,
- emb_dim,
- )
- )
- return nn.ModuleList(embeddings)
__all__ = [
"BLTPreTrainedModel",
"BLTModel",