diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py
index 72ac1652284..3854ca5e139 100644
--- a/src/transformers/models/blt_wip/modeling_blt.py
+++ b/src/transformers/models/blt_wip/modeling_blt.py
@@ -40,14 +40,11 @@ from .configuration_blt import (
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask
 
-    from ...integrations.flex_attention import make_flex_block_causal_mask
-
 
 logger = logging.get_logger(__name__)
 
 
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -59,7 +56,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
-# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP
 class BLTMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -75,7 +71,6 @@ class BLTMLP(nn.Module):
         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
         return down_proj
 
-# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
@@ -135,7 +130,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_rot.type_as(q), k_rot.type_as(k)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
 class BLTRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -235,7 +229,6 @@ class BLTTransformerLayer(nn.Module):
 
         return outputs
 
-# Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
 class BLTSelfAttention(nn.Module):
     def __init__(self, config, layer_idx: int):
         super().__init__()
@@ -316,59 +309,66 @@ class BLTSelfAttention(nn.Module):
             attn_weights = None
 
         return attn_output, attn_weights, past_key_value
-
-
-
-def check_non_zero_after_zero(tensor):
-    zero_mask = tensor == 0
-    shifted_mask = torch.cat(
-        [
-            torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
-            zero_mask[:, :-1],
-        ],
-        dim=1,
-    )
-    non_zero_after_zero = (tensor != 0) & shifted_mask
-    return non_zero_after_zero.any()
+
 
 def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0):
     primes = [
-        1000000007,
-        5915587277,
-        1500450271,
-        3267000013,
-        5754853343,
-        4093082899,
-        9576890767,
-        3628273133,
-        2860486313,
-        5463458053,
-        3367900313,
+        1000000007, 5915587277, 1500450271, 3267000013, 5754853343,
+        4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313,
     ]
     prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device)
-    prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])])
+    powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
+    prime_powers = prime ** powers
     return torch.sum(token_tensor * prime_powers, dim=-1)
 
 
 def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
-    """
-    Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash].
-
-    expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab.
-    - returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash].
-
-    Note: max hash can make a big difference on the number of collisions.
-    """
+    """Hash token groups and map to range [0, max_hash]."""
     with torch.no_grad():
         batch_size, seq_len = token_ids.shape
-        prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
-        token_ids = torch.cat([prefix, token_ids], dim=1)
-        windows = token_ids.unfold(1, group_size, 1)
-        # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
+        # Add padding for sliding window
+        padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
+        padded_tokens = torch.cat([padding, token_ids], dim=1)
+
+        # Create sliding windows and compute hashes
+        windows = padded_tokens.unfold(1, group_size, 1)
         hashes = rolling_polynomial_hash(windows, hash_func_nb)
-        hash_values_range = hashes % max_hash
-        hash_values_range.requires_grad = False
-        return hash_values_range
+        hash_values = hashes % max_hash
+
+        hash_values.requires_grad = False
+        return hash_values
+
+
+def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list):
+    """Initialize hash-based token embeddings for the BLT encoder."""
+    num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
+    embeddings = [
+        nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
+        for _ in range(num_embeddings)
+    ]
+    return nn.ModuleList(embeddings)
+
+
+def compute_hash_embeddings(
+    local_encoder_tokens: torch.Tensor,
+    local_encoder,
+    encoder_hash_tok_embedding: nn.ModuleList,
+    encoder_hash_byte_group_nb_functions: int,
+    encoder_hash_byte_group_size: list,
+    encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+    """Compute token embeddings enhanced with hash-based embeddings."""
+    embeddings = local_encoder.embed_tokens(local_encoder_tokens)
+    embedding_idx = 0
+    for func_nb in range(encoder_hash_byte_group_nb_functions):
+        for group_size in encoder_hash_byte_group_size:
+            hash_ids = byte_group_hash_function(
+                local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab
+            )
+            embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids)
+            embedding_idx += 1
+
+    return embeddings
 
 
 def _prepare_patch_cross_attention_mask(
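
Review note, not part of the patch: the hunk above vectorizes `rolling_polynomial_hash` (a single `torch.arange` exponentiation replaces the `torch.stack` over a Python loop) and renames locals in `byte_group_hash_function` without changing behavior. A minimal torch-only sketch of the windowed hash, with illustrative values:

    import torch

    token_ids = torch.randint(0, 256, (2, 8), dtype=torch.int64)  # (batch, seq_len) byte ids
    group_size = 2
    padding = torch.zeros(2, group_size - 1, dtype=torch.int64)
    windows = torch.cat([padding, token_ids], dim=1).unfold(1, group_size, 1)  # (2, 8, 2)
    prime_powers = torch.tensor(1000000007, dtype=torch.int64) ** torch.arange(group_size)
    hash_ids = (windows * prime_powers).sum(-1) % 30000  # (2, 8), values in [0, 30000)

For group sizes of three and above, the products can wrap around int64 for the larger primes; that wraparound (effectively hashing modulo 2**64 before the final `% max_hash`) exists in both the old and new versions.
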
@@ -455,34 +455,47 @@
     return cross_attention_mask, full_text_row_masked_out_mask
 
-def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
-    #TODO: refactor to be more readable
+
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+    """
+    Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+    Pads the result to uniform length across the batch.
+
+    Args:
+        patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+        max_patch_length (int, optional): Maximum allowed length per patch.
+
+    Returns:
+        torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+    """
     if max_patch_length is None:
         return patch_lengths
 
     batch_size = patch_lengths.size(0)
-    split_all = []
-    max_len = 0
+    processed = []
 
     for seq in patch_lengths:
         splits = []
         for length in seq[seq > 0]:
-            # Split long patches into max_patch_length chunks
-            full, rem = divmod(length.item(), max_patch_length)
-            splits.extend([max_patch_length] * full + ([rem] if rem else []))
-        split_all.append(splits)
-        max_len = max(max_len, len(splits))
+            length = length.item()
+            full_chunks, remainder = divmod(length, max_patch_length)
+            splits.extend([max_patch_length] * full_chunks)
+            if remainder:
+                splits.append(remainder)
+        processed.append(splits)
 
-    # Pad sequences to the maximum length
+    # Find max length to pad to
+    max_len = max(len(splits) for splits in processed)
     padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
-    for i, splits in enumerate(split_all):
+
+    for i, splits in enumerate(processed):
         if splits:
             padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
 
-    # Trim trailing columns that are all zeros
-    last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
-    if last_non_zero < padded.shape[1]:
-        padded = padded[:, :padded.shape[1] - last_non_zero]
+    # Trim zero columns
+    if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+        last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+        padded = padded[:, :last_nonzero]
 
     return padded
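
Review note: the rewritten `process_patch_lengths` keeps the documented splitting and padding semantics. A small worked example, with illustrative values:

    import torch

    patch_lengths = torch.tensor([[5, 2, 0],
                                  [9, 0, 0]])
    # With max_patch_length=4: row 0 splits 5 -> [4, 1] and 2 -> [2],
    # row 1 splits 9 -> [4, 4, 1]; rows are then padded to the longest row.
    out = process_patch_lengths(patch_lengths, 4)
    assert out.tolist() == [[4, 1, 2], [4, 4, 1]]
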
+ """ if max_patch_length is None: return patch_lengths batch_size = patch_lengths.size(0) - split_all = [] - max_len = 0 + processed = [] for seq in patch_lengths: splits = [] for length in seq[seq > 0]: - # Split long patches into max_patch_length chunks - full, rem = divmod(length.item(), max_patch_length) - splits.extend([max_patch_length] * full + ([rem] if rem else [])) - split_all.append(splits) - max_len = max(max_len, len(splits)) + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) - # Pad sequences to the maximum length + # Find max length to pad to + max_len = max(len(splits) for splits in processed) padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) - for i, splits in enumerate(split_all): + + for i, splits in enumerate(processed): if splits: padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - # Trim trailing columns that are all zeros - last_non_zero = (padded != 0).flip(1).int().argmax(1).min() - if last_non_zero < padded.shape[1]: - padded = padded[:, :padded.shape[1] - last_non_zero] + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] return padded @@ -627,12 +640,11 @@ class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTLocalDecoderConfig): super().__init__() - # Extract config values to instance attributes self.hidden_size = config.hidden_size self.vocab_size=config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_decoder = True self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k @@ -655,11 +667,7 @@ class BLTLocalDecoder(nn.Module): BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) ) - self.lm_head = nn.Linear( - self.hidden_size, - self.vocab_size, - bias=False, - ) + self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False,) def forward( @@ -851,48 +859,6 @@ class BLTGlobalTransformer(nn.Module): return hidden_states, cache -def compute_hash_embeddings( - local_encoder_tokens: torch.Tensor, - local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, - encoder_hash_byte_group_nb_functions: int, - encoder_hash_byte_group_size: list, - encoder_hash_byte_group_vocab: int, -) -> torch.Tensor: - """ - Compute embeddings using hash token embeddings. 
@@ -982,7 +948,7 @@ class BLTModel(BLTPreTrainedModel):
 
         # Patcher initialization
         if self.patch_in_forward:
-            self.patcher = BLTPatcher(config)
+            self.patcher = BLTPatcher(config.patcher_config)
             self.patcher.eval()
             for param in self.patcher.parameters():
                 param.requires_grad = False
@@ -1076,8 +1042,8 @@ class BLTModel(BLTPreTrainedModel):
 
 class BLTPatcher(BLTPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config.patcher_config)
+    def __init__(self, config: BLTPatcherConfig):
+        super().__init__(config)
 
         self.rotary_emb = BLTRotaryEmbedding(config=self.config)
 
@@ -1223,34 +1189,6 @@ class BLTPatcher(BLTPreTrainedModel):
 
         return patch_lengths
 
-
-def init_hash_embeddings(
-    config,
-    local_encoder_dim: int,
-    encoder_hash_byte_group_size: list,
-):
-    """Initialize hash-based token embeddings for the BLT encoder."""
-    if config.encoder_hash_byte_group_size is None:
-        return None
-
-    embeddings = []
-    emb_dim = local_encoder_dim
-    encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
-
-    for _ in range(config.encoder_hash_byte_group_nb_functions):
-        for _ in encoder_hash_byte_group_size:
-            embeddings.append(
-                nn.Embedding(
-                    encoder_hash_byte_group_vocab,
-                    emb_dim,
-                )
-            )
-
-    return nn.ModuleList(embeddings)
-
-
-
-
 __all__ = [
     "BLTPreTrainedModel",
     "BLTModel",
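
Review note: the last three hunks move sub-config unwrapping to the caller: `BLTModel` now instantiates its frozen patcher with `config.patcher_config`, and `BLTPatcher.__init__` accepts a `BLTPatcherConfig` directly. A hedged sketch of the new call shape; the top-level `BLTConfig` name and its default constructibility are assumptions, not shown in this patch:

    from transformers.models.blt_wip.configuration_blt import BLTConfig  # assumed export
    from transformers.models.blt_wip.modeling_blt import BLTPatcher

    config = BLTConfig()                         # assumption: defaults are valid
    patcher = BLTPatcher(config.patcher_config)  # new: pass the sub-config explicitly
    # old: BLTPatcher(config), which unwrapped config.patcher_config internally
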