fix most tests

This commit is contained in:
oweller2 2025-06-22 00:57:54 -04:00
parent c865c1d896
commit b1ef0868ef
6 changed files with 47 additions and 17 deletions

View File

@@ -104,8 +104,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
local_rope_theta (`float`, *optional*):
The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
num_labels (`int`, *optional*, defaults to 2):
Number of labels for sequence classification.
Examples:
@@ -159,7 +157,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
local_attention=128,
global_attn_every_n_layers=3,
local_rope_theta=None,
num_labels=2,
**kwargs,
):
super().__init__(
@@ -197,7 +194,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
self.local_attention = local_attention
self.global_attn_every_n_layers = global_attn_every_n_layers
self.local_rope_theta = local_rope_theta
self.num_labels = num_labels
def to_dict(self):
output = super().to_dict()
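Why the explicit `num_labels` handling above can simply be deleted: `PretrainedConfig.__init__` already consumes a `num_labels` kwarg (with the same default of 2) and exposes it on the config. A minimal sketch with a toy config subclass, not taken from this commit:

from transformers import PretrainedConfig

class ToyDecoderConfig(PretrainedConfig):
    def __init__(self, hidden_size=64, **kwargs):
        # num_labels (plus the id2label/label2id bookkeeping) is handled by the base class
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

cfg = ToyDecoderConfig(num_labels=3)
print(cfg.num_labels)               # 3, without the subclass storing it itself
print(ToyDecoderConfig().num_labels)  # 2, the base-class default the removed code duplicated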

View File

@@ -213,7 +213,8 @@ class ModernBertDecoderAttention(nn.Module):
else:
# For initial forward pass, start from 0
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Apply rotary embeddings
if past_key_value is not None:
@@ -977,7 +978,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
super().__init__(config)
self.num_labels = config.num_labels
self.model = ModernBertDecoderModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.head = ModernBertPredictionHead(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
self.drop = torch.nn.Dropout(config.classifier_dropout)
# Initialize weights and apply final processing
self.post_init()
@@ -1035,7 +1039,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
hidden_states = self.drop(self.head(hidden_states))
logits = self.classifier(hidden_states)
if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]
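For orientation, a standalone sketch of the new classification path, not taken from this commit. `ModernBertPredictionHead` is stood in for by a dense + activation + norm stack (an assumption based on the encoder-side ModernBert head), and the dropout and bias values stand in for what `config.classifier_dropout` and `config.classifier_bias` would provide:

import torch
import torch.nn as nn

hidden_size, num_labels = 64, 2

# Hypothetical stand-in for ModernBertPredictionHead (dense -> activation -> norm).
head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.LayerNorm(hidden_size))
drop = nn.Dropout(p=0.1)                          # would come from config.classifier_dropout
classifier = nn.Linear(hidden_size, num_labels)   # bias would follow config.classifier_bias

hidden_states = torch.randn(4, 10, hidden_size)   # (batch, seq_len, hidden) from the decoder
logits = classifier(drop(head(hidden_states)))    # same composition as the new forward pass
print(logits.shape)  # torch.Size([4, 10, 2]); the last non-padding position is picked afterwards

Compared with the old single `score` projection, this adds a non-linearity and dropout before the label projection, mirroring how the encoder-side ModernBert classification models build their heads.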

View File

@@ -123,8 +123,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
local_rope_theta (`float`, *optional*):
The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
num_labels (`int`, *optional*, defaults to 2):
Number of labels for sequence classification.
Examples:
@@ -178,7 +176,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
local_attention=128,
global_attn_every_n_layers=3,
local_rope_theta=None,
num_labels=2,
**kwargs,
):
super().__init__(
@@ -216,7 +213,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
self.local_attention = local_attention
self.global_attn_every_n_layers = global_attn_every_n_layers
self.local_rope_theta = local_rope_theta
self.num_labels = num_labels
def to_dict(self):
output = super().to_dict()
@@ -443,7 +439,7 @@ class ModernBertDecoderAttention(nn.Module):
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
# Create position_ids if None
# Create position_ids that respect padding tokens if not provided
if position_ids is None:
device = hidden_states.device
if past_key_value is not None:
@@ -451,9 +447,18 @@ class ModernBertDecoderAttention(nn.Module):
cache_length = past_key_value[0].shape[-2]
position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
else:
# For initial forward pass, start from 0
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# For initial forward pass, create position_ids that respect padding
if attention_mask is not None:
# Create cumulative sum of attention_mask to get proper positions
# This ensures padding tokens don't increment position
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
else:
# Fallback: sequential positions
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Apply rotary embeddings
if past_key_value is not None:
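A small standalone illustration of the padding-aware positions built above, not part of the diff; the mask values are made up for the example:

import torch

# Left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum counts real tokens (1-based), so subtract 1 to get 0-based positions.
position_ids = attention_mask.long().cumsum(-1) - 1
# Padding slots end up at -1; pin them to 0 (they are masked out of attention anyway).
position_ids.masked_fill_(attention_mask == 0, 0)

print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

Note that this path already yields a (batch, seq_len) tensor, so the trailing `dim() == 1` guard only expands the 1-D output of the `torch.arange` fallbacks.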
@@ -1153,7 +1158,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
super().__init__(config)
self.num_labels = config.num_labels
self.model = ModernBertDecoderModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.head = ModernBertPredictionHead(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
self.drop = torch.nn.Dropout(config.classifier_dropout)
# Initialize weights and apply final processing
self.post_init()
@@ -1211,7 +1219,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
hidden_states = self.drop(self.head(hidden_states))
logits = self.classifier(hidden_states)
if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]

View File

@@ -22,6 +22,7 @@ from transformers.testing_utils import (
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
from ...test_modeling_common import _config_zero_init
if is_torch_available():
@@ -170,3 +171,19 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
# Check that loss is computed
self.assertIsNotNone(outputs_with_loss.loss)
self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
# The classifier and head weights of ModernBertDecoderForSequenceClassification and ModernBertDecoderForCausalLM
# are initialized without `initializer_range`, so they are not set to ~0 by `_config_zero_init`
if param.requires_grad and not (name == "classifier.weight" or name == "head.weight"):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
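For context, a hedged sketch of what the `[0.0, 1.0]` assertion relies on: `_config_zero_init` shrinks config-driven init scales such as `initializer_range` to a tiny value, so weights drawn from them average out to ~0 while norm weights stay at exactly 1. Standalone illustration, not taken from the test suite:

import torch
import torch.nn as nn

linear = nn.Linear(16, 16)
nn.init.normal_(linear.weight, mean=0.0, std=1e-10)  # roughly what a zeroed-out initializer_range produces
norm = nn.LayerNorm(16)                              # weight starts at 1.0 regardless of the config

for name, param in [("linear.weight", linear.weight), ("norm.weight", norm.weight)]:
    rounded = ((param.data.mean() * 1e9).round() / 1e9).item()
    assert rounded in (0.0, 1.0), f"{name} seems not properly initialized"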

View File

@@ -2974,6 +2974,8 @@ class ModelTesterMixin:
"ModernBertForTokenClassification",
"TimmWrapperForImageClassification",
"ModernBertForQuestionAnswering",
"ModernBertDecoderForSequenceClassification",
"ModernBertDecoderForCausalLM",
]
special_param_names = [
r"^bit\.",

View File

@@ -279,6 +279,7 @@ SPECIAL_CASES_TO_ALLOW = {
"max_position_embeddings",
"mlp_bias",
"mlp_dropout",
"classifier_activation",
],
}