Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
fixed tests
This commit is contained in:
parent b1ef0868ef
commit 9de1db1e8b
@@ -129,6 +129,31 @@ def eager_attention_forward(
     attn_weights = attn_weights + causal_mask
 
+    # Apply attention mask (for padding) if provided
+    if attention_mask is not None:
+        # Check if the attention mask has the correct kv dimension
+        if attention_mask.dim() == 4 and attention_mask.shape[-1] != kv_seq_len:
+            # The mask was pre-computed but doesn't match our kv sequence length
+            # This happens when we have past key values that extend the sequence
+            current_mask_kv_len = attention_mask.shape[-1]
+
+            if current_mask_kv_len < kv_seq_len:
+                # Need to extend the mask to cover past keys
+                # Past keys should be allowed (not masked), so we pad with zeros
+                past_length = kv_seq_len - current_mask_kv_len
+
+                # Create padding for past positions (zeros = allow attention)
+                past_padding = torch.zeros(
+                    (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Concatenate past padding with current mask
+                attention_mask = torch.cat([past_padding, attention_mask], dim=-1)
+
+        attn_weights = attn_weights + attention_mask
 
     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
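For context, the block above left-pads a pre-computed 4-D additive mask along the key axis so that cached keys stay visible (zeros mean "attend", large negative values mean "masked"). A minimal sketch of the same idea, with made-up shapes rather than the model's real tensors:

    # Sketch only: extend an additive attention mask with zeros on the key axis
    # so previously cached key positions remain attendable. Shapes are invented
    # for illustration; this is not the transformers implementation itself.
    import torch

    batch, heads, q_len, cur_kv_len, kv_seq_len = 2, 4, 1, 1, 6

    # Additive mask covering only the current tokens: 0.0 = attend.
    attention_mask = torch.zeros(batch, heads, q_len, cur_kv_len)

    past_length = kv_seq_len - attention_mask.shape[-1]
    if past_length > 0:
        # Past keys should remain visible, so pad with zeros on the key axis.
        past_padding = torch.zeros(
            (batch, heads, q_len, past_length),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([past_padding, attention_mask], dim=-1)

    assert attention_mask.shape == (batch, heads, q_len, kv_seq_len)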
@@ -203,7 +228,7 @@ class ModernBertDecoderAttention(nn.Module):
         qkv = self.Wqkv(hidden_states)
         qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
 
-        # Create position_ids if None
+        # Fallback: create position_ids only if not provided
         if position_ids is None:
             device = hidden_states.device
             if past_key_value is not None:
@@ -213,8 +238,11 @@ class ModernBertDecoderAttention(nn.Module):
             else:
                 # For initial forward pass, start from 0
                 position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-                if position_ids.dim() == 1:
-                    position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # Ensure position_ids has the right shape if provided externally
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
         # Apply rotary embeddings
         if past_key_value is not None:
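The shape-normalization step added above broadcasts 1-D position ids to (batch_size, seq_len). A small illustrative sketch, with arbitrary sizes:

    # Sketch: broadcasting 1-D position ids to (batch_size, seq_len).
    import torch

    batch_size, seq_len = 3, 5
    position_ids = torch.arange(seq_len, dtype=torch.long)  # shape (seq_len,)

    if position_ids.dim() == 1:
        # expand() adds a batch dimension without copying memory.
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

    print(position_ids.shape)  # torch.Size([3, 5])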
@@ -611,24 +639,19 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         else:
             batch_size, seq_length = inputs_embeds.shape[:2]
 
-        # Create cache_position - this is the sequence positions for the current tokens
+        # Create cache_position - this represents the actual tensor positions (not semantic positions)
         # Calculate past seen tokens first
         if past_key_values is None:
             past_seen_tokens = 0
         else:
             past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values[0] is not None else 0
 
-        # Create cache_position using position_ids if available (to respect padding)
-        if position_ids is not None:
-            # Use the actual positions from position_ids for cache_position
-            cache_position = (
-                position_ids[0] if position_ids.shape[0] > 0 else torch.arange(seq_length, device=position_ids.device)
-            )
-        else:
-            cache_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + seq_length,
-                device=input_ids.device if input_ids is not None else inputs_embeds.device,
-            )
+        # Always use sequential positions like other HF models, regardless of padding
+        cache_position = torch.arange(
+            past_seen_tokens,
+            past_seen_tokens + seq_length,
+            device=input_ids.device if input_ids is not None else inputs_embeds.device,
+        )
 
         # Create position_ids that respect padding tokens if not provided
         if position_ids is None:
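After this change, cache_position simply counts tensor slots from the number of already-cached key/value tokens, independent of padding. A sketch of that computation, assuming a hypothetical single-layer tuple-style cache:

    # Sketch: sequential cache_position for a decode step, assuming a legacy
    # tuple-style cache where keys have shape (batch, heads, past_len, head_dim).
    import torch

    past_len, seq_length = 7, 1  # e.g. one new token after 7 cached ones
    past_key = torch.zeros(2, 4, past_len, 16)
    past_key_values = ((past_key, past_key),)  # hypothetical single-layer cache

    past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values is not None else 0
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length)
    print(cache_position)  # tensor([7])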
@@ -639,7 +662,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 # Create cumulative sum of attention_mask to get proper positions
                 # This ensures padding tokens don't increment position
                 position_ids = attention_mask.long().cumsum(-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 0)
+                position_ids.masked_fill_(attention_mask == 0, 1)
             else:
                 # Fallback: sequential positions
                 position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
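The fill value for padding slots changes from 0 to 1 here; the cumsum trick itself is unchanged. A standalone sketch with a made-up left-padded batch:

    # Sketch: padding-aware position ids from an attention mask
    # (1 = real token, 0 = padding); padding slots get a dummy position of 1.
    import torch

    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [0, 0, 1, 1]])  # second row is left-padded

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)
    # tensor([[0, 1, 2, 3],
    #         [1, 1, 0, 1]])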
@@ -656,7 +679,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 current_positions = past_real_tokens + torch.arange(seq_length, device=device)
                 # Only increment for non-padding tokens in current sequence
                 current_mask = attention_mask[..., -seq_length:]
-                position_ids = current_positions.masked_fill(current_mask == 0, 0)
+                position_ids = current_positions.masked_fill(current_mask == 0, 1)
             else:
                 # Fallback: continue sequentially
                 position_ids = torch.arange(
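For the generation path, positions continue from the count of real (non-padding) past tokens, and padding in the current step is again filled with 1. A sketch under the assumption that past_real_tokens holds the per-sequence count of non-padding cached tokens (it is computed above this hunk and not shown here):

    # Sketch: continuing position ids during generation. attention_mask covers
    # past + current tokens; only non-padding past tokens advance the position.
    import torch

    seq_length = 1                                    # one new token per step
    attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # 2 pads, 2 real past, 1 new

    past_mask = attention_mask[..., :-seq_length]
    past_real_tokens = past_mask.sum(dim=-1, keepdim=True)  # tensor([[2]])
    current_positions = past_real_tokens + torch.arange(seq_length)

    current_mask = attention_mask[..., -seq_length:]
    position_ids = current_positions.masked_fill(current_mask == 0, 1)
    print(position_ids)  # tensor([[2]])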
@@ -365,6 +365,31 @@ def eager_attention_forward(
     attn_weights = attn_weights + causal_mask
 
+    # Apply attention mask (for padding) if provided
+    if attention_mask is not None:
+        # Check if the attention mask has the correct kv dimension
+        if attention_mask.dim() == 4 and attention_mask.shape[-1] != kv_seq_len:
+            # The mask was pre-computed but doesn't match our kv sequence length
+            # This happens when we have past key values that extend the sequence
+            current_mask_kv_len = attention_mask.shape[-1]
+
+            if current_mask_kv_len < kv_seq_len:
+                # Need to extend the mask to cover past keys
+                # Past keys should be allowed (not masked), so we pad with zeros
+                past_length = kv_seq_len - current_mask_kv_len
+
+                # Create padding for past positions (zeros = allow attention)
+                past_padding = torch.zeros(
+                    (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Concatenate past padding with current mask
+                attention_mask = torch.cat([past_padding, attention_mask], dim=-1)
+
+        attn_weights = attn_weights + attention_mask
 
     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -439,7 +464,7 @@ class ModernBertDecoderAttention(nn.Module):
         qkv = self.Wqkv(hidden_states)
         qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
 
-        # Create position_ids that respect padding tokens if not provided
+        # Fallback: create position_ids only if not provided
         if position_ids is None:
             device = hidden_states.device
             if past_key_value is not None:
@@ -447,18 +472,13 @@ class ModernBertDecoderAttention(nn.Module):
                 cache_length = past_key_value[0].shape[-2]
                 position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
             else:
-                # For initial forward pass, create position_ids that respect padding
-                if attention_mask is not None:
-                    # Create cumulative sum of attention_mask to get proper positions
-                    # This ensures padding tokens don't increment position
-                    position_ids = attention_mask.long().cumsum(-1) - 1
-                    position_ids.masked_fill_(attention_mask == 0, 0)
-                else:
-                    # Fallback: sequential positions
-                    position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
-                if position_ids.dim() == 1:
-                    position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+                # For initial forward pass, start from 0
+                position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
+                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # Ensure position_ids has the right shape if provided externally
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
         # Apply rotary embeddings
         if past_key_value is not None:
@@ -791,24 +811,19 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         else:
             batch_size, seq_length = inputs_embeds.shape[:2]
 
-        # Create cache_position - this is the sequence positions for the current tokens
+        # Create cache_position - this represents the actual tensor positions (not semantic positions)
         # Calculate past seen tokens first
        if past_key_values is None:
             past_seen_tokens = 0
         else:
             past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values[0] is not None else 0
 
-        # Create cache_position using position_ids if available (to respect padding)
-        if position_ids is not None:
-            # Use the actual positions from position_ids for cache_position
-            cache_position = (
-                position_ids[0] if position_ids.shape[0] > 0 else torch.arange(seq_length, device=position_ids.device)
-            )
-        else:
-            cache_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + seq_length,
-                device=input_ids.device if input_ids is not None else inputs_embeds.device,
-            )
+        # Always use sequential positions like other HF models, regardless of padding
+        cache_position = torch.arange(
+            past_seen_tokens,
+            past_seen_tokens + seq_length,
+            device=input_ids.device if input_ids is not None else inputs_embeds.device,
+        )
 
         # Create position_ids that respect padding tokens if not provided
         if position_ids is None:
@@ -819,7 +834,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 # Create cumulative sum of attention_mask to get proper positions
                 # This ensures padding tokens don't increment position
                 position_ids = attention_mask.long().cumsum(-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 0)
+                position_ids.masked_fill_(attention_mask == 0, 1)
             else:
                 # Fallback: sequential positions
                 position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
@@ -836,7 +851,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 current_positions = past_real_tokens + torch.arange(seq_length, device=device)
                 # Only increment for non-padding tokens in current sequence
                 current_mask = attention_mask[..., -seq_length:]
-                position_ids = current_positions.masked_fill(current_mask == 0, 0)
+                position_ids = current_positions.masked_fill(current_mask == 0, 1)
             else:
                 # Fallback: continue sequentially
                 position_ids = torch.arange(
@@ -22,7 +22,6 @@ from transformers.testing_utils import (
 )
 
 from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
-from ...test_modeling_common import _config_zero_init
 
 
 if is_torch_available():
@@ -171,19 +170,3 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
         # Check that loss is computed
         self.assertIsNotNone(outputs_with_loss.loss)
         self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
-
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                # The classifier.weight from ModernBertDecoderForSequenceClassification and ModernBertDecoderForCausalLM
-                # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
-                if param.requires_grad and not (name == "classifier.weight" or name == "head.weight"):
-                    self.assertIn(
-                        ((param.data.mean() * 1e9).round() / 1e9).item(),
-                        [0.0, 1.0],
-                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                    )
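For reference, the deleted test zeroed the config's initializer_range via _config_zero_init and asserted that every trainable parameter's mean rounds to 0.0 or 1.0. A stripped-down sketch of that pattern, with a hypothetical toy module standing in for the real model classes:

    # Sketch of the zero-init check pattern used by the removed test; the toy
    # module below is illustrative only, not part of transformers.
    import torch
    from torch import nn

    class ToyModel(nn.Module):
        def __init__(self, initializer_range: float = 1e-10):
            super().__init__()
            self.dense = nn.Linear(8, 8)
            self.norm = nn.LayerNorm(8)  # weight defaults to ones, bias to zeros
            nn.init.normal_(self.dense.weight, std=initializer_range)
            nn.init.zeros_(self.dense.bias)

    model = ToyModel(initializer_range=1e-10)  # ~zero init, like _config_zero_init
    for name, param in model.named_parameters():
        if param.requires_grad:
            rounded_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
            # Weights drawn with std ~1e-10 round to 0.0; LayerNorm weight is 1.0.
            assert rounded_mean in [0.0, 1.0], f"{name} not properly initialized"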