From b1ef0868ef6fedf6adc24ff0f62c7f2b0a294f5c Mon Sep 17 00:00:00 2001
From: oweller2
Date: Sun, 22 Jun 2025 00:57:54 -0400
Subject: [PATCH] fix most tests

---
 .../configuration_modernbert_decoder.py  |  4 ---
 .../modeling_modernbert_decoder.py        | 11 +++++--
 .../modular_modernbert_decoder.py         | 29 ++++++++++++-------
 .../test_modeling_modernbert_decoder.py   | 17 +++++++++++
 tests/test_modeling_common.py             |  2 ++
 utils/check_config_attributes.py          |  1 +
 6 files changed, 47 insertions(+), 17 deletions(-)

diff --git a/src/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py
index c7d77c2c132..4bbeae99c9c 100644
--- a/src/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py
@@ -104,8 +104,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
             Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
         local_rope_theta (`float`, *optional*):
             The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
-        num_labels (`int`, *optional*, defaults to 2):
-            Number of labels for sequence classification.

     Examples:

@@ -159,7 +157,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         local_attention=128,
         global_attn_every_n_layers=3,
         local_rope_theta=None,
-        num_labels=2,
         **kwargs,
     ):
         super().__init__(
@@ -197,7 +194,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         self.local_attention = local_attention
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.local_rope_theta = local_rope_theta
-        self.num_labels = num_labels

     def to_dict(self):
         output = super().to_dict()
diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index bb25e7d2888..8753da0057d 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -213,7 +213,8 @@ class ModernBertDecoderAttention(nn.Module):
             else:
                 # For initial forward pass, start from 0
                 position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+            if position_ids.dim() == 1:
+                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

         # Apply rotary embeddings
         if past_key_value is not None:
@@ -977,7 +978,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1035,7 +1039,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        hidden_states = self.drop(self.head(hidden_states))
+        logits = self.classifier(hidden_states)

         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
index 2c75c120479..847d8766664 100644
--- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -123,8 +123,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
             Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
         local_rope_theta (`float`, *optional*):
             The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
-        num_labels (`int`, *optional*, defaults to 2):
-            Number of labels for sequence classification.

     Examples:

@@ -178,7 +176,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         local_attention=128,
         global_attn_every_n_layers=3,
         local_rope_theta=None,
-        num_labels=2,
         **kwargs,
     ):
         super().__init__(
@@ -216,7 +213,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         self.local_attention = local_attention
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.local_rope_theta = local_rope_theta
-        self.num_labels = num_labels

     def to_dict(self):
         output = super().to_dict()
@@ -443,7 +439,7 @@ class ModernBertDecoderAttention(nn.Module):
         qkv = self.Wqkv(hidden_states)
         qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)

-        # Create position_ids if None
+        # Create position_ids that respect padding tokens if not provided
         if position_ids is None:
             device = hidden_states.device
             if past_key_value is not None:
@@ -451,9 +447,18 @@ class ModernBertDecoderAttention(nn.Module):
                 cache_length = past_key_value[0].shape[-2]
                 position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
             else:
-                # For initial forward pass, start from 0
-                position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+                # For initial forward pass, create position_ids that respect padding
+                if attention_mask is not None:
+                    # Create cumulative sum of attention_mask to get proper positions
+                    # This ensures padding tokens don't increment position
+                    position_ids = attention_mask.long().cumsum(-1) - 1
+                    position_ids.masked_fill_(attention_mask == 0, 0)
+                else:
+                    # Fallback: sequential positions
+                    position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
+                    position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+            if position_ids.dim() == 1:
+                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

         # Apply rotary embeddings
         if past_key_value is not None:
@@ -1153,7 +1158,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1211,7 +1219,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        hidden_states = self.drop(self.head(hidden_states))
+        logits = self.classifier(hidden_states)

         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
diff --git a/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py b/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py
index 9f0eb4657cb..b4294dfa2a2 100644
--- a/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py
+++ b/tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py
@@ -22,6 +22,7 @@ from transformers.testing_utils import (
 )

 from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+from ...test_modeling_common import _config_zero_init


 if is_torch_available():
@@ -170,3 +171,19 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
         # Check that loss is computed
         self.assertIsNotNone(outputs_with_loss.loss)
         self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
+
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                # The classifier.weight from ModernBertDecoderForSequenceClassification and ModernBertDecoderForCausalLM
+                # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
+                if param.requires_grad and not (name == "classifier.weight" or name == "head.weight"):
+                    self.assertIn(
+                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        [0.0, 1.0],
+                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                    )
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 4e2555b57ef..9441a055b5e 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2974,6 +2974,8 @@ class ModelTesterMixin:
             "ModernBertForTokenClassification",
             "TimmWrapperForImageClassification",
             "ModernBertForQuestionAnswering",
+            "ModernBertDecoderForSequenceClassification",
+            "ModernBertDecoderForCausalLM",
         ]
         special_param_names = [
             r"^bit\.",
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index dac58e3104e..bc6d6b8250c 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -279,6 +279,7 @@ SPECIAL_CASES_TO_ALLOW = {
         "max_position_embeddings",
         "mlp_bias",
         "mlp_dropout",
+        "classifier_activation",
     ],
 }
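
Note (not part of the patch): the core of the attention change above is deriving position ids from the attention mask so that padded slots never advance the rotary index. Below is a minimal standalone sketch of that logic, assuming a 2D `attention_mask` of shape `(batch_size, seq_len)` with 1 for real tokens and 0 for padding; the helper name `padding_aware_position_ids` is illustrative, not an API in the patch.

import torch

def padding_aware_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    # Cumulative sum over the mask counts only real tokens; subtracting 1 makes
    # the positions 0-based, and padded slots are clamped back to 0 so they
    # never increment the position used for rotary embeddings.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)
    return position_ids

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
print(padding_aware_position_ids(mask))
# tensor([[0, 1, 2, 0, 0],
#         [0, 1, 2, 3, 4]])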
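
A second note on the sequence-classification change: the bare `score` projection is replaced with a `head -> dropout -> classifier` path. The sketch below mirrors that forward path under simplified assumptions; `TinyPredictionHead` is an illustrative stand-in for `ModernBertPredictionHead` (the real module is built from the model config), and the sizes and dropout rate are arbitrary.

import torch
from torch import nn

class TinyPredictionHead(nn.Module):
    # Stand-in head: dense projection, activation, then layer norm.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size, bias=False)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.act(self.dense(hidden_states)))

hidden_size, num_labels = 64, 2
head = TinyPredictionHead(hidden_size)
drop = nn.Dropout(0.1)                            # plays the role of config.classifier_dropout
classifier = nn.Linear(hidden_size, num_labels)   # plays the role of the new classifier layer

hidden_states = torch.randn(2, 8, hidden_size)    # (batch, seq_len, hidden)
logits = classifier(drop(head(hidden_states)))    # mirrors the patched forward()
print(logits.shape)  # torch.Size([2, 8, 2])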