fixed tests

oweller2 2025-06-22 01:39:12 -04:00
parent b1ef0868ef
commit 9de1db1e8b
3 changed files with 84 additions and 63 deletions


@@ -129,6 +129,31 @@ def eager_attention_forward(
attn_weights = attn_weights + causal_mask
# Apply attention mask (for padding) if provided
if attention_mask is not None:
# Check if the attention mask has the correct kv dimension
if attention_mask.dim() == 4 and attention_mask.shape[-1] != kv_seq_len:
# The mask was pre-computed but doesn't match our kv sequence length
# This happens when we have past key values that extend the sequence
current_mask_kv_len = attention_mask.shape[-1]
if current_mask_kv_len < kv_seq_len:
# Need to extend the mask to cover past keys
# Past keys should be allowed (not masked), so we pad with zeros
past_length = kv_seq_len - current_mask_kv_len
# Create padding for past positions (zeros = allow attention)
past_padding = torch.zeros(
(attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Concatenate past padding with current mask
attention_mask = torch.cat([past_padding, attention_mask], dim=-1)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
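The hunk above pads a precomputed 4-D additive padding mask along the key dimension so that it also covers cached (past) tokens. A minimal standalone sketch of just that padding step, with illustrative shapes (batch, heads and sequence lengths are made up, not the model's):

import torch

batch, heads, q_len, cur_kv_len, past_len = 2, 1, 3, 3, 5
kv_seq_len = past_len + cur_kv_len

# Additive mask over the current tokens only: 0.0 = attend, -inf = masked.
attention_mask = torch.zeros(batch, heads, q_len, cur_kv_len)
attention_mask[0, :, :, -1] = float("-inf")  # pretend sample 0's last current token is padding

if attention_mask.dim() == 4 and attention_mask.shape[-1] < kv_seq_len:
    past_length = kv_seq_len - attention_mask.shape[-1]
    # Past keys are real tokens, so they stay unmasked (zeros) and are
    # prepended along the last (key) dimension.
    past_padding = torch.zeros(
        (*attention_mask.shape[:3], past_length),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    attention_mask = torch.cat([past_padding, attention_mask], dim=-1)

print(attention_mask.shape)  # torch.Size([2, 1, 3, 8])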
@@ -203,7 +228,7 @@ class ModernBertDecoderAttention(nn.Module):
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
# Create position_ids if None
# Fallback: create position_ids only if not provided
if position_ids is None:
device = hidden_states.device
if past_key_value is not None:
@@ -213,8 +238,11 @@ class ModernBertDecoderAttention(nn.Module):
else:
# For initial forward pass, start from 0
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Ensure position_ids has the right shape if provided externally
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Apply rotary embeddings
if past_key_value is not None:
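As a reading aid for the two hunks above: sequential position_ids are built only when the caller did not pass any (offset by the cache length on decode steps), and externally supplied 1-D ids are broadcast to (batch_size, seq_len). A self-contained sketch; the helper name and the way the cache length is passed are made up, with cache_length standing in for past_key_value[0].shape[-2]:

import torch

def fallback_position_ids(position_ids, batch_size, seq_len, cache_length, device="cpu"):
    if position_ids is None:
        # Prefill starts at 0; decode steps continue after the cached tokens.
        position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
    if position_ids.dim() == 1:
        # Broadcast a single row of positions across the batch.
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
    return position_ids

print(fallback_position_ids(None, 2, 4, 0))               # prefill: tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
print(fallback_position_ids(torch.tensor([4]), 2, 1, 4))  # decode step with 1-D ids from the caller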
@@ -611,24 +639,19 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
else:
batch_size, seq_length = inputs_embeds.shape[:2]
# Create cache_position - this is the sequence positions for the current tokens
# Create cache_position - this represents the actual tensor positions (not semantic positions)
# Calculate past seen tokens first
if past_key_values is None:
past_seen_tokens = 0
else:
past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values[0] is not None else 0
# Create cache_position using position_ids if available (to respect padding)
if position_ids is not None:
# Use the actual positions from position_ids for cache_position
cache_position = (
position_ids[0] if position_ids.shape[0] > 0 else torch.arange(seq_length, device=position_ids.device)
)
else:
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + seq_length,
device=input_ids.device if input_ids is not None else inputs_embeds.device,
)
# Always use sequential positions like other HF models, regardless of padding
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + seq_length,
device=input_ids.device if input_ids is not None else inputs_embeds.device,
)
# Create position_ids that respect padding tokens if not provided
if position_ids is None:
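The replacement above always derives cache_position as sequential tensor positions starting at the number of tokens already in the cache. A small sketch, assuming a legacy tuple-of-tuples cache where past_key_values[0][0] is the first layer's key tensor of shape (batch, heads, kv_len, head_dim):

import torch

def make_cache_position(past_key_values, seq_length, device="cpu"):
    if past_key_values is None or past_key_values[0] is None:
        past_seen_tokens = 0
    else:
        # kv_len of the first layer's key tensor = tokens already cached.
        past_seen_tokens = past_key_values[0][0].shape[-2]
    return torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)

key = value = torch.zeros(2, 4, 5, 8)           # 5 tokens already cached
print(make_cache_position(None, 3))             # tensor([0, 1, 2])
print(make_cache_position(((key, value),), 1))  # tensor([5])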
@@ -639,7 +662,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
# Create cumulative sum of attention_mask to get proper positions
# This ensures padding tokens don't increment position
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
position_ids.masked_fill_(attention_mask == 0, 1)
else:
# Fallback: sequential positions
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
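For a left-padded prefill batch, the cumulative-sum recipe above yields positions that skip padded slots, and the masked_fill value of 1 is only a placeholder (those keys are masked out in attention anyway). A quick worked example:

import torch

attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])  # 0 = padding, 1 = real token

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])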
@@ -656,7 +679,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
current_positions = past_real_tokens + torch.arange(seq_length, device=device)
# Only increment for non-padding tokens in current sequence
current_mask = attention_mask[..., -seq_length:]
position_ids = current_positions.masked_fill(current_mask == 0, 0)
position_ids = current_positions.masked_fill(current_mask == 0, 1)
else:
# Fallback: continue sequentially
position_ids = torch.arange(
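The decode-step branch above continues each sample's positions from the count of real tokens it has already seen. A sketch of that branch; how past_real_tokens is computed falls outside the hunk, so deriving it from the cached part of the attention mask is an assumption here:

import torch

def decode_step_position_ids(attention_mask, seq_length, device="cpu"):
    # Per-sample count of non-padding tokens already in the cache (assumed definition).
    past_real_tokens = attention_mask[..., :-seq_length].sum(dim=-1, keepdim=True)
    current_positions = past_real_tokens + torch.arange(seq_length, device=device)
    # Padded slots among the new tokens get the placeholder value 1, matching the prefill branch.
    current_mask = attention_mask[..., -seq_length:]
    return current_positions.masked_fill(current_mask == 0, 1)

mask = torch.tensor([[1, 1, 1, 1, 1],
                     [0, 1, 1, 1, 1]])  # 4 cached slots + 1 new token; sample 2 was left-padded
print(decode_step_position_ids(mask, seq_length=1))
# tensor([[4],
#         [3]])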


@@ -365,6 +365,31 @@ def eager_attention_forward(
attn_weights = attn_weights + causal_mask
# Apply attention mask (for padding) if provided
if attention_mask is not None:
# Check if the attention mask has the correct kv dimension
if attention_mask.dim() == 4 and attention_mask.shape[-1] != kv_seq_len:
# The mask was pre-computed but doesn't match our kv sequence length
# This happens when we have past key values that extend the sequence
current_mask_kv_len = attention_mask.shape[-1]
if current_mask_kv_len < kv_seq_len:
# Need to extend the mask to cover past keys
# Past keys should be allowed (not masked), so we pad with zeros
past_length = kv_seq_len - current_mask_kv_len
# Create padding for past positions (zeros = allow attention)
past_padding = torch.zeros(
(attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Concatenate past padding with current mask
attention_mask = torch.cat([past_padding, attention_mask], dim=-1)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -439,7 +464,7 @@ class ModernBertDecoderAttention(nn.Module):
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
# Create position_ids that respect padding tokens if not provided
# Fallback: create position_ids only if not provided
if position_ids is None:
device = hidden_states.device
if past_key_value is not None:
@@ -447,18 +472,13 @@ class ModernBertDecoderAttention(nn.Module):
cache_length = past_key_value[0].shape[-2]
position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
else:
# For initial forward pass, create position_ids that respect padding
if attention_mask is not None:
# Create cumulative sum of attention_mask to get proper positions
# This ensures padding tokens don't increment position
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
else:
# Fallback: sequential positions
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# For initial forward pass, start from 0
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Ensure position_ids has the right shape if provided externally
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Apply rotary embeddings
if past_key_value is not None:
@@ -791,24 +811,19 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
else:
batch_size, seq_length = inputs_embeds.shape[:2]
# Create cache_position - this is the sequence positions for the current tokens
# Create cache_position - this represents the actual tensor positions (not semantic positions)
# Calculate past seen tokens first
if past_key_values is None:
past_seen_tokens = 0
else:
past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values[0] is not None else 0
# Create cache_position using position_ids if available (to respect padding)
if position_ids is not None:
# Use the actual positions from position_ids for cache_position
cache_position = (
position_ids[0] if position_ids.shape[0] > 0 else torch.arange(seq_length, device=position_ids.device)
)
else:
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + seq_length,
device=input_ids.device if input_ids is not None else inputs_embeds.device,
)
# Always use sequential positions like other HF models, regardless of padding
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + seq_length,
device=input_ids.device if input_ids is not None else inputs_embeds.device,
)
# Create position_ids that respect padding tokens if not provided
if position_ids is None:
@@ -819,7 +834,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
# Create cumulative sum of attention_mask to get proper positions
# This ensures padding tokens don't increment position
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
position_ids.masked_fill_(attention_mask == 0, 1)
else:
# Fallback: sequential positions
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
@@ -836,7 +851,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
current_positions = past_real_tokens + torch.arange(seq_length, device=device)
# Only increment for non-padding tokens in current sequence
current_mask = attention_mask[..., -seq_length:]
position_ids = current_positions.masked_fill(current_mask == 0, 0)
position_ids = current_positions.masked_fill(current_mask == 0, 1)
else:
# Fallback: continue sequentially
position_ids = torch.arange(


@@ -22,7 +22,6 @@ from transformers.testing_utils import (
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
from ...test_modeling_common import _config_zero_init
if is_torch_available():
@@ -171,19 +170,3 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
# Check that loss is computed
self.assertIsNotNone(outputs_with_loss.loss)
self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
# The classifier.weight from ModernBertDecoderForSequenceClassification and ModernBertDecoderForCausalLM
# are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
if param.requires_grad and not (name == "classifier.weight" or name == "head.weight"):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
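The removed test built each model with _config_zero_init(config), which shrinks the initializer scale to effectively zero, and then asserted that every trainable parameter's mean rounds to 0.0 (near-zero weights) or 1.0 (LayerNorm-style scales). A toy version of that check on a stand-in module; TinyBlock and its std value are made up for the sketch:

import torch.nn as nn

class TinyBlock(nn.Module):
    # Stand-in for a model built from a zeroed-out config.
    def __init__(self, hidden=8, initializer_range=1e-10):
        super().__init__()
        self.dense = nn.Linear(hidden, hidden)
        self.norm = nn.LayerNorm(hidden)  # weight = 1, bias = 0 by default
        nn.init.normal_(self.dense.weight, std=initializer_range)  # effectively zero
        nn.init.zeros_(self.dense.bias)

model = TinyBlock()
for name, param in model.named_parameters():
    rounded_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
    assert rounded_mean in (0.0, 1.0), f"Parameter {name} seems not properly initialized"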