mirror of https://github.com/huggingface/transformers.git
clean up patcher helpers further
This commit is contained in:
parent 5c218735fb
commit 09ad5a51ea
@@ -40,14 +40,11 @@ from .configuration_blt import (

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,

@@ -59,7 +56,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
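For context, a minimal sketch (illustrative only, toy sizes) of the shape transformation repeat_kv performs:

import torch

kv = torch.randn(2, 4, 16, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=8)     # uses repeat_kv as defined above; each KV head repeated 8 times
print(out.shape)                 # torch.Size([2, 32, 16, 64])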

# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP
class BLTMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -75,7 +71,6 @@ class BLTMLP(nn.Module):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,

@@ -135,7 +130,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

    return q_rot.type_as(q), k_rot.type_as(k)


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class BLTRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """

@@ -235,7 +229,6 @@ class BLTTransformerLayer(nn.Module):
        return outputs


# Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT
class BLTSelfAttention(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()

@@ -316,59 +309,66 @@ class BLTSelfAttention(nn.Module):
            attn_weights = None

        return attn_output, attn_weights, past_key_value


-def check_non_zero_after_zero(tensor):
-    zero_mask = tensor == 0
-    shifted_mask = torch.cat(
-        [
-            torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
-            zero_mask[:, :-1],
-        ],
-        dim=1,
-    )
-    non_zero_after_zero = (tensor != 0) & shifted_mask
-    return non_zero_after_zero.any()

def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0):
    primes = [
-        1000000007,
-        5915587277,
-        1500450271,
-        3267000013,
-        5754853343,
-        4093082899,
-        9576890767,
-        3628273133,
-        2860486313,
-        5463458053,
-        3367900313,
+        1000000007, 5915587277, 1500450271, 3267000013, 5754853343,
+        4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313,
    ]
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device)
-    prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])])
+    powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
+    prime_powers = prime ** powers
    return torch.sum(token_tensor * prime_powers, dim=-1)
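For context, a small sketch (illustrative only, toy window tensor) showing that the arange-based prime powers match the previous stack-based construction:

import torch

windows = torch.tensor([[[1, 2], [2, 3]]], dtype=torch.int64)   # (batch, seq, group_size)
prime = torch.tensor(1000000007, dtype=torch.int64)
old_powers = torch.stack([prime**i for i in range(windows.shape[-1])])
new_powers = prime ** torch.arange(windows.shape[-1])
assert torch.equal(old_powers, new_powers)
print(rolling_polynomial_hash(windows))   # uses the function above; tensor([[2000000015, 3000000023]])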


def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
-    """
-    Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash].
-
-    expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab.
-    returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash].
-
-    Note: max hash can make a big difference on the number of collisions.
-    """
+    """Hash token groups and map to range [0, max_hash]."""
    with torch.no_grad():
        batch_size, seq_len = token_ids.shape
-        prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
-        token_ids = torch.cat([prefix, token_ids], dim=1)
-        windows = token_ids.unfold(1, group_size, 1)
-        # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
+        # Add padding for sliding window
+        padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
+        padded_tokens = torch.cat([padding, token_ids], dim=1)
+
+        # Create sliding windows and compute hashes
+        windows = padded_tokens.unfold(1, group_size, 1)
        hashes = rolling_polynomial_hash(windows, hash_func_nb)
-        hash_values_range = hashes % max_hash
-        hash_values_range.requires_grad = False
-        return hash_values_range
+        hash_values = hashes % max_hash
+
+    hash_values.requires_grad = False
+    return hash_values
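For context, a minimal usage sketch (illustrative only, toy byte ids); each position is hashed together with the group_size - 1 preceding ids and bucketed into [0, max_hash):

import torch

token_ids = torch.tensor([[72, 101, 108, 108, 111]])   # (batch_size, seq_len)
hash_ids = byte_group_hash_function(token_ids, group_size=3, hash_func_nb=0, max_hash=30000)
print(hash_ids.shape)                                   # torch.Size([1, 5])
print((hash_ids >= 0).all(), (hash_ids < 30000).all())  # tensor(True) tensor(True)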
+
+
+def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list):
+    """Initialize hash-based token embeddings for the BLT encoder."""
+    num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
+    embeddings = [
+        nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
+        for _ in range(num_embeddings)
+    ]
+    return nn.ModuleList(embeddings)
+
+
+def compute_hash_embeddings(
+    local_encoder_tokens: torch.Tensor,
+    local_encoder,
+    encoder_hash_tok_embedding: nn.ModuleList,
+    encoder_hash_byte_group_nb_functions: int,
+    encoder_hash_byte_group_size: list,
+    encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+    """Compute token embeddings enhanced with hash-based embeddings."""
+    embeddings = local_encoder.embed_tokens(local_encoder_tokens)
+
+    embedding_idx = 0
+    for func_nb in range(encoder_hash_byte_group_nb_functions):
+        for group_size in encoder_hash_byte_group_size:
+            hash_ids = byte_group_hash_function(
+                local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab
+            )
+            embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids)
+            embedding_idx += 1
+
+    return embeddings
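For context, a sketch of how the two helpers fit together (illustrative only; the config object and encoder below are toy stand-ins built for this demo):

import torch
import torch.nn as nn
from types import SimpleNamespace

config = SimpleNamespace(encoder_hash_byte_group_nb_functions=2, encoder_hash_byte_group_vocab=30000)
group_sizes = [2, 3]
local_encoder = SimpleNamespace(embed_tokens=nn.Embedding(260, 64))   # toy byte embedding table

hash_embeddings = init_hash_embeddings(config, local_encoder_dim=64, encoder_hash_byte_group_size=group_sizes)
tokens = torch.randint(0, 260, (1, 16))
embeds = compute_hash_embeddings(
    tokens, local_encoder, hash_embeddings,
    config.encoder_hash_byte_group_nb_functions, group_sizes, config.encoder_hash_byte_group_vocab,
)
print(len(hash_embeddings), embeds.shape)   # 4 torch.Size([1, 16, 64])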


def _prepare_patch_cross_attention_mask(

@@ -455,34 +455,47 @@ def _prepare_patch_cross_attention_mask(

    return cross_attention_mask, full_text_row_masked_out_mask


-def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
-    #TODO: refactor to be more readable
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+    """
+    Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+    Pads the result to uniform length across the batch.
+
+    Args:
+        patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+        max_patch_length (int, optional): Maximum allowed length per patch.
+
+    Returns:
+        torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+    """
    if max_patch_length is None:
        return patch_lengths

    batch_size = patch_lengths.size(0)
-    split_all = []
-    max_len = 0
+    processed = []

    for seq in patch_lengths:
        splits = []
        for length in seq[seq > 0]:
            # Split long patches into max_patch_length chunks
-            full, rem = divmod(length.item(), max_patch_length)
-            splits.extend([max_patch_length] * full + ([rem] if rem else []))
-        split_all.append(splits)
-        max_len = max(max_len, len(splits))
+            length = length.item()
+            full_chunks, remainder = divmod(length, max_patch_length)
+            splits.extend([max_patch_length] * full_chunks)
+            if remainder:
+                splits.append(remainder)
+        processed.append(splits)

-    # Pad sequences to the maximum length
+    # Find max length to pad to
+    max_len = max(len(splits) for splits in processed)
    padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
-    for i, splits in enumerate(split_all):
+    for i, splits in enumerate(processed):
+        if splits:
            padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)

-    # Trim trailing columns that are all zeros
-    last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
-    if last_non_zero < padded.shape[1]:
-        padded = padded[:, :padded.shape[1] - last_non_zero]
+    # Trim zero columns
+    if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+        last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+        padded = padded[:, :last_nonzero]

    return padded
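For context, a minimal sketch of the rewritten helper on toy patch lengths (illustrative only); patches longer than max_patch_length are split into chunks and the batch is re-padded:

import torch

patch_lengths = torch.tensor([[5, 2, 0], [1, 1, 1]])
print(process_patch_lengths(patch_lengths, max_patch_length=2))   # uses the function above
# tensor([[2, 2, 1, 2],
#         [1, 1, 1, 0]])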

@@ -627,12 +640,11 @@ class BLTLocalDecoder(nn.Module):
    def __init__(self, config: BLTLocalDecoderConfig):
        super().__init__()

        # Extract config values to instance attributes
        self.hidden_size = config.hidden_size
        self.vocab_size=config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.dropout = config.dropout
-        self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove
+        self.cross_attn_decoder = True
        self.cross_attn_all_layers = config.cross_attn_all_layers
        self.cross_attn_k = config.cross_attn_k

@@ -655,11 +667,7 @@ class BLTLocalDecoder(nn.Module):
                BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size)
            )

-        self.lm_head = nn.Linear(
-            self.hidden_size,
-            self.vocab_size,
-            bias=False,
-        )
+        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False,)

    def forward(

@@ -851,48 +859,6 @@ class BLTGlobalTransformer(nn.Module):
        return hidden_states, cache


-def compute_hash_embeddings(
-    local_encoder_tokens: torch.Tensor,
-    local_encoder,
-    encoder_hash_tok_embedding: nn.ModuleList,
-    encoder_hash_byte_group_nb_functions: int,
-    encoder_hash_byte_group_size: list,
-    encoder_hash_byte_group_vocab: int,
-) -> torch.Tensor:
-    """
-    Compute embeddings using hash token embeddings.
-
-    Args:
-        local_encoder_tokens: Input tokens tensor
-        local_encoder: Encoder object with embed_tokens method
-        encoder_hash_tok_embedding: ModuleList of hash token embeddings
-        encoder_hash_byte_group_nb_functions: Number of hash functions
-        encoder_hash_byte_group_size: List of byte group sizes
-        encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
-
-    Returns:
-        torch.Tensor: Combined embeddings
-    """
-    if encoder_hash_tok_embedding is None:
-        return None
-
-    local_encoder_embeds = local_encoder.embed_tokens(local_encoder_tokens)
-
-    i = 0
-    for func_nb in range(encoder_hash_byte_group_nb_functions):
-        for byte_group_size in encoder_hash_byte_group_size:
-            hash_ids = byte_group_hash_function(
-                local_encoder_tokens,
-                byte_group_size,
-                hash_func_nb=func_nb,
-                max_hash=encoder_hash_byte_group_vocab,
-            )
-            hash_tok_embedding = encoder_hash_tok_embedding[i]
-            local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
-            i += 1
-
-    assert i == len(encoder_hash_tok_embedding)
-    return local_encoder_embeds


class BLTPreTrainedModel(PreTrainedModel):

@@ -982,7 +948,7 @@ class BLTModel(BLTPreTrainedModel):

        # Patcher initialization
        if self.patch_in_forward:
-            self.patcher = BLTPatcher(config)
+            self.patcher = BLTPatcher(config.patcher_config)
            self.patcher.eval()
            for param in self.patcher.parameters():
                param.requires_grad = False

@@ -1076,8 +1042,8 @@


class BLTPatcher(BLTPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config.patcher_config)
+    def __init__(self, config: BLTPatcherConfig):
+        super().__init__(config)

        self.rotary_emb = BLTRotaryEmbedding(config=self.config)

@@ -1223,34 +1189,6 @@ class BLTPatcher(BLTPreTrainedModel):

        return patch_lengths


-def init_hash_embeddings(
-    config,
-    local_encoder_dim: int,
-    encoder_hash_byte_group_size: list,
-):
-    """Initialize hash-based token embeddings for the BLT encoder."""
-    if config.encoder_hash_byte_group_size is None:
-        return None
-
-    embeddings = []
-    emb_dim = local_encoder_dim
-    encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
-
-    for _ in range(config.encoder_hash_byte_group_nb_functions):
-        for _ in encoder_hash_byte_group_size:
-            embeddings.append(
-                nn.Embedding(
-                    encoder_hash_byte_group_vocab,
-                    emb_dim,
-                )
-            )
-
-    return nn.ModuleList(embeddings)


__all__ = [
    "BLTPreTrainedModel",
    "BLTModel",