Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00

fix most tests

commit b1ef0868ef (parent c865c1d896)

@@ -104,8 +104,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
             Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
         local_rope_theta (`float`, *optional*):
             The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
-        num_labels (`int`, *optional*, defaults to 2):
-            Number of labels for sequence classification.

     Examples:

@@ -159,7 +157,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         local_attention=128,
         global_attn_every_n_layers=3,
         local_rope_theta=None,
-        num_labels=2,
         **kwargs,
     ):
         super().__init__(

@@ -197,7 +194,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         self.local_attention = local_attention
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.local_rope_theta = local_rope_theta
-        self.num_labels = num_labels

     def to_dict(self):
         output = super().to_dict()

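For context, `num_labels` keeps working after its removal from the explicit signature because `PretrainedConfig.__init__` already accepts it through `**kwargs`. A minimal sketch, assuming `ModernBertDecoderConfig` is exported at the top level like other model configs:

    from transformers import ModernBertDecoderConfig

    config = ModernBertDecoderConfig(
        local_attention=128,
        global_attn_every_n_layers=3,
        local_rope_theta=None,  # falls back to global_rope_theta
        num_labels=2,           # no longer a named argument; forwarded via **kwargs to PretrainedConfig
    )
    print(config.num_labels)  # 2
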
@@ -213,7 +213,8 @@ class ModernBertDecoderAttention(nn.Module):
             else:
                 # For initial forward pass, start from 0
                 position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

         # Apply rotary embeddings
         if past_key_value is not None:

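The added `dim()` guard covers the case where `position_ids` arrives as a 1-D tensor. A standalone illustration of the broadcast, in plain PyTorch with toy sizes:

    import torch

    batch_size, seq_len = 2, 4
    position_ids = torch.arange(seq_len, dtype=torch.long)  # shape (seq_len,)
    if position_ids.dim() == 1:
        # expand views the same row for every batch element without copying memory
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
    print(position_ids.shape)  # torch.Size([2, 4])
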
@@ -977,7 +978,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)

         # Initialize weights and apply final processing
         self.post_init()

@@ -1035,7 +1039,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        hidden_states = self.drop(self.head(hidden_states))
+        logits = self.classifier(hidden_states)

         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]

@@ -123,8 +123,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
             Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
         local_rope_theta (`float`, *optional*):
             The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
-        num_labels (`int`, *optional*, defaults to 2):
-            Number of labels for sequence classification.

     Examples:

@@ -178,7 +176,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         local_attention=128,
         global_attn_every_n_layers=3,
         local_rope_theta=None,
-        num_labels=2,
         **kwargs,
     ):
         super().__init__(

@@ -216,7 +213,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         self.local_attention = local_attention
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.local_rope_theta = local_rope_theta
-        self.num_labels = num_labels

     def to_dict(self):
         output = super().to_dict()

@@ -443,7 +439,7 @@ class ModernBertDecoderAttention(nn.Module):
         qkv = self.Wqkv(hidden_states)
         qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)

-        # Create position_ids if None
+        # Create position_ids that respect padding tokens if not provided
         if position_ids is None:
             device = hidden_states.device
             if past_key_value is not None:

@@ -451,9 +447,18 @@ class ModernBertDecoderAttention(nn.Module):
                 cache_length = past_key_value[0].shape[-2]
                 position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
             else:
-                # For initial forward pass, start from 0
-                position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+                # For initial forward pass, create position_ids that respect padding
+                if attention_mask is not None:
+                    # Create cumulative sum of attention_mask to get proper positions
+                    # This ensures padding tokens don't increment position
+                    position_ids = attention_mask.long().cumsum(-1) - 1
+                    position_ids.masked_fill_(attention_mask == 0, 0)
+                else:
+                    # Fallback: sequential positions
+                    position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
+                    position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

         # Apply rotary embeddings
         if past_key_value is not None:

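To illustrate the padding-aware positions introduced above, here is the cumulative-sum trick on a toy left-padded batch (plain PyTorch; the mask values are made up for the example):

    import torch

    # Toy batch: 0 marks padding, 1 marks real tokens.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    # Cumulative sum counts the real tokens seen so far; subtracting 1 makes positions 0-based.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Padding slots would end up at -1; clamp them to 0 so they remain valid indices.
    position_ids.masked_fill_(attention_mask == 0, 0)

    print(position_ids)
    # tensor([[0, 0, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])
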
@@ -1153,7 +1158,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)

         # Initialize weights and apply final processing
         self.post_init()

@@ -1211,7 +1219,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        hidden_states = self.drop(self.head(hidden_states))
+        logits = self.classifier(hidden_states)

         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]

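The classification path now mirrors ModernBERT's encoder-style heads: hidden states go through a prediction head, dropout, then a dense classifier. A rough standalone sketch of that data flow, using a stand-in `nn.Sequential` in place of `ModernBertPredictionHead` (whose exact layout is not shown in this diff):

    import torch
    from torch import nn

    hidden_size, num_labels, dropout_p = 768, 2, 0.1

    # Stand-in modules wired the same way as in __init__ above; the real head is
    # ModernBertPredictionHead, approximated here as dense -> activation -> norm.
    head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.LayerNorm(hidden_size))
    drop = nn.Dropout(dropout_p)
    classifier = nn.Linear(hidden_size, num_labels)

    hidden_states = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)
    hidden_states = drop(head(hidden_states))
    logits = classifier(hidden_states)              # (batch, seq_len, num_labels)
    print(logits.shape)                             # torch.Size([2, 5, 2])
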
@@ -22,6 +22,7 @@ from transformers.testing_utils import (
 )

 from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+from ...test_modeling_common import _config_zero_init


 if is_torch_available():

@@ -170,3 +171,19 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
         # Check that loss is computed
         self.assertIsNotNone(outputs_with_loss.loss)
         self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
+
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                # The classifier.weight from ModernBertDecoderForSequenceClassification and ModernBertDecoderForCausalLM
+                # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
+                if param.requires_grad and not (name == "classifier.weight" or name == "head.weight"):
+                    self.assertIn(
+                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        [0.0, 1.0],
+                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                    )

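For readers unfamiliar with the helper, `_config_zero_init` (imported above from `test_modeling_common`) shrinks the config's initializer-scale attributes so that properly initialized weights come out at ~0, which is what the `assertIn(..., [0.0, 1.0])` check relies on. A simplified sketch of the idea, not the exact helper:

    import copy

    def config_zero_init_sketch(config):
        # Copy the config and push every init-scale style attribute (e.g. initializer_range)
        # toward zero, so weights drawn from those distributions land at ~0 after init.
        configs_no_init = copy.deepcopy(config)
        for key, value in configs_no_init.to_dict().items():
            if isinstance(value, float) and ("initializer" in key or key.endswith("_range") or key.endswith("_std")):
                setattr(configs_no_init, key, 1e-10)
        return configs_no_init
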
@@ -2974,6 +2974,8 @@ class ModelTesterMixin:
             "ModernBertForTokenClassification",
             "TimmWrapperForImageClassification",
             "ModernBertForQuestionAnswering",
+            "ModernBertDecoderForSequenceClassification",
+            "ModernBertDecoderForCausalLM",
         ]
         special_param_names = [
             r"^bit\.",

@@ -279,6 +279,7 @@ SPECIAL_CASES_TO_ALLOW = {
         "max_position_embeddings",
         "mlp_bias",
         "mlp_dropout",
+        "classifier_activation",
     ],
 }
