diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index 8753da0057d..7c53532733c 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -129,6 +129,31 @@ def eager_attention_forward(
 
         attn_weights = attn_weights + causal_mask
 
+    # Apply attention mask (for padding) if provided
+    if attention_mask is not None:
+        # Check if the attention mask has the correct kv dimension
+        if attention_mask.dim() == 4 and attention_mask.shape[-1] != kv_seq_len:
+            # The mask was pre-computed but doesn't match our kv sequence length
+            # This happens when we have past key values that extend the sequence
+            current_mask_kv_len = attention_mask.shape[-1]
+
+            if current_mask_kv_len < kv_seq_len:
+                # Need to extend the mask to cover past keys
+                # Past keys should be allowed (not masked), so we pad with zeros
+                past_length = kv_seq_len - current_mask_kv_len
+
+                # Create padding for past positions (zeros = allow attention)
+                past_padding = torch.zeros(
+                    (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Concatenate past padding with current mask
+                attention_mask = torch.cat([past_padding, attention_mask], dim=-1)
+
+        attn_weights = attn_weights + attention_mask
+
     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -203,7 +228,7 @@ class ModernBertDecoderAttention(nn.Module):
         qkv = self.Wqkv(hidden_states)
         qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
 
-        # Create position_ids if None
+        # Fallback: create position_ids only if not provided
         if position_ids is None:
             device = hidden_states.device
             if past_key_value is not None:
@@ -213,8 +238,11 @@
             else:
                 # For initial forward pass, start from 0
                 position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-        if position_ids.dim() == 1:
-            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # Ensure position_ids has the right shape if provided externally
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
         # Apply rotary embeddings
         if past_key_value is not None:
@@ -611,24 +639,19 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         else:
             batch_size, seq_length = inputs_embeds.shape[:2]
 
-        # Create cache_position - this is the sequence positions for the current tokens
+        # Create cache_position - this represents the actual tensor positions (not semantic positions)
+        # Calculate past seen tokens first
         if past_key_values is None:
             past_seen_tokens = 0
         else:
             past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values[0] is not None else 0
 
-        # Create cache_position using position_ids if available (to respect padding)
-        if position_ids is not None:
-            # Use the actual positions from position_ids for cache_position
-            cache_position = (
-                position_ids[0] if position_ids.shape[0] > 0 else torch.arange(seq_length, device=position_ids.device)
-            )
-        else:
-            cache_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + seq_length,
-                device=input_ids.device if input_ids is not None else inputs_embeds.device,
-            )
+        # Always use sequential positions like other HF models, regardless of padding
+        cache_position = torch.arange(
+            past_seen_tokens,
+            past_seen_tokens + seq_length,
+            device=input_ids.device if input_ids is not None else inputs_embeds.device,
+        )
 
         # Create position_ids that respect padding tokens if not provided
         if position_ids is None:
@@ -639,7 +662,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 # Create cumulative sum of attention_mask to get proper positions
                 # This ensures padding tokens don't increment position
                 position_ids = attention_mask.long().cumsum(-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 0)
+                position_ids.masked_fill_(attention_mask == 0, 1)
             else:
                 # Fallback: sequential positions
                 position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
@@ -656,7 +679,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 current_positions = past_real_tokens + torch.arange(seq_length, device=device)
                 # Only increment for non-padding tokens in current sequence
                 current_mask = attention_mask[..., -seq_length:]
-                position_ids = current_positions.masked_fill(current_mask == 0, 0)
+                position_ids = current_positions.masked_fill(current_mask == 0, 1)
             else:
                 # Fallback: continue sequentially
                 position_ids = torch.arange(
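Reviewer note on the `eager_attention_forward` hunk above: a minimal standalone sketch (not part of the patch) of the mask-extension logic, runnable outside the model. The batch/head/length values are illustrative assumptions, not values taken from ModernBERT.

```python
# Sketch (assumed shapes): a 4D additive padding mask that covers only the
# current query tokens is left-padded with zeros so cached (past) key
# positions remain attendable, mirroring the hunk above.
import torch

batch, heads, q_len, past_len = 2, 4, 3, 5
kv_seq_len = past_len + q_len

# Additive mask over current tokens only: 0.0 = attend, -inf = masked
attention_mask = torch.zeros(batch, heads, q_len, q_len)
attention_mask[0, :, :, -1] = float("-inf")  # pretend sample 0's last token is padding

if attention_mask.dim() == 4 and attention_mask.shape[-1] < kv_seq_len:
    past_length = kv_seq_len - attention_mask.shape[-1]
    past_padding = torch.zeros(
        (*attention_mask.shape[:3], past_length),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    # Past keys sit at the front of the kv axis, so the zeros are prepended
    attention_mask = torch.cat([past_padding, attention_mask], dim=-1)

print(attention_mask.shape)  # torch.Size([2, 4, 3, 8])
```

Using zeros (rather than -inf) for the past slots encodes "always attend", which matches the patch's assumption that anything already written into the kv cache was a real token.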
diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
index 847d8766664..3eb121b27d2 100644
--- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -365,6 +365,31 @@ def eager_attention_forward(
 
         attn_weights = attn_weights + causal_mask
 
+    # Apply attention mask (for padding) if provided
+    if attention_mask is not None:
+        # Check if the attention mask has the correct kv dimension
+        if attention_mask.dim() == 4 and attention_mask.shape[-1] != kv_seq_len:
+            # The mask was pre-computed but doesn't match our kv sequence length
+            # This happens when we have past key values that extend the sequence
+            current_mask_kv_len = attention_mask.shape[-1]
+
+            if current_mask_kv_len < kv_seq_len:
+                # Need to extend the mask to cover past keys
+                # Past keys should be allowed (not masked), so we pad with zeros
+                past_length = kv_seq_len - current_mask_kv_len
+
+                # Create padding for past positions (zeros = allow attention)
+                past_padding = torch.zeros(
+                    (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Concatenate past padding with current mask
+                attention_mask = torch.cat([past_padding, attention_mask], dim=-1)
+
+        attn_weights = attn_weights + attention_mask
+
     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -439,7 +464,7 @@ class ModernBertDecoderAttention(nn.Module):
         qkv = self.Wqkv(hidden_states)
         qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
 
-        # Create position_ids that respect padding tokens if not provided
+        # Fallback: create position_ids only if not provided
         if position_ids is None:
             device = hidden_states.device
             if past_key_value is not None:
@@ -447,18 +472,13 @@ class ModernBertDecoderAttention(nn.Module):
                 cache_length = past_key_value[0].shape[-2]
                 position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
             else:
-                # For initial forward pass, create position_ids that respect padding
-                if attention_mask is not None:
-                    # Create cumulative sum of attention_mask to get proper positions
-                    # This ensures padding tokens don't increment position
-                    position_ids = attention_mask.long().cumsum(-1) - 1
-                    position_ids.masked_fill_(attention_mask == 0, 0)
-                else:
-                    # Fallback: sequential positions
-                    position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
-        if position_ids.dim() == 1:
-            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+                # For initial forward pass, start from 0
+                position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
+                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # Ensure position_ids has the right shape if provided externally
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
         # Apply rotary embeddings
         if past_key_value is not None:
@@ -791,24 +811,19 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         else:
             batch_size, seq_length = inputs_embeds.shape[:2]
 
-        # Create cache_position - this is the sequence positions for the current tokens
+        # Create cache_position - this represents the actual tensor positions (not semantic positions)
+        # Calculate past seen tokens first
         if past_key_values is None:
             past_seen_tokens = 0
         else:
             past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values[0] is not None else 0
 
-        # Create cache_position using position_ids if available (to respect padding)
-        if position_ids is not None:
-            # Use the actual positions from position_ids for cache_position
-            cache_position = (
-                position_ids[0] if position_ids.shape[0] > 0 else torch.arange(seq_length, device=position_ids.device)
-            )
-        else:
-            cache_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + seq_length,
-                device=input_ids.device if input_ids is not None else inputs_embeds.device,
-            )
+        # Always use sequential positions like other HF models, regardless of padding
+        cache_position = torch.arange(
+            past_seen_tokens,
+            past_seen_tokens + seq_length,
+            device=input_ids.device if input_ids is not None else inputs_embeds.device,
+        )
 
         # Create position_ids that respect padding tokens if not provided
         if position_ids is None:
@@ -819,7 +834,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 # Create cumulative sum of attention_mask to get proper positions
                 # This ensures padding tokens don't increment position
                 position_ids = attention_mask.long().cumsum(-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 0)
+                position_ids.masked_fill_(attention_mask == 0, 1)
             else:
                 # Fallback: sequential positions
                 position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
@@ -836,7 +851,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
                 current_positions = past_real_tokens + torch.arange(seq_length, device=device)
                 # Only increment for non-padding tokens in current sequence
                 current_mask = attention_mask[..., -seq_length:]
-                position_ids = current_positions.masked_fill(current_mask == 0, 0)
+                position_ids = current_positions.masked_fill(current_mask == 0, 1)
             else:
                 # Fallback: continue sequentially
                 position_ids = torch.arange(
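The position_ids hunks in both files keep the standard cumsum-over-mask pattern, with the padding fill value changed from 0 to 1 to match other HF decoder models. A small sketch of just that computation, with made-up inputs for illustration:

```python
import torch

# Left-padded batch: 0 = padding, 1 = real token (illustrative values)
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Padding tokens don't increment the position counter
position_ids = attention_mask.long().cumsum(-1) - 1
# The fill value for padded slots is arbitrary, since those tokens are masked
# out of attention anyway; 1 matches the convention used across transformers
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```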
diff --git a/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py b/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py
index b4294dfa2a2..9f0eb4657cb 100644
--- a/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py
+++ b/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py
@@ -22,7 +22,6 @@ from transformers.testing_utils import (
 )
 
 from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
-from ...test_modeling_common import _config_zero_init
 
 
 if is_torch_available():
@@ -171,19 +170,3 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
         # Check that loss is computed
         self.assertIsNotNone(outputs_with_loss.loss)
         self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
-
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                # The classifier.weight from ModernBertDecoderForSequenceClassification and ModernBertDecoderForCausalLM
-                # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
-                if param.requires_grad and not (name == "classifier.weight" or name == "head.weight"):
-                    self.assertIn(
-                        ((param.data.mean() * 1e9).round() / 1e9).item(),
-                        [0.0, 1.0],
-                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                    )
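Finally, the cache_position simplification in both model files: slot indices into the cache are now always sequential, decoupled from the padding-aware position_ids above. A sketch under the patch's assumption that the past length comes from the cached key tensor:

```python
import torch

past_seen_tokens = 6  # e.g. past_key_values[0][0].shape[-2] in the patch
seq_length = 2        # tokens in the current forward pass

# Tensor slot positions, independent of padding; semantic positions live in position_ids
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length)
print(cache_position)  # tensor([6, 7])
```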