Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)

Commit b1ef0868ef ("fix most tests"), parent c865c1d896
@@ -104,8 +104,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
             Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
         local_rope_theta (`float`, *optional*):
             The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
-        num_labels (`int`, *optional*, defaults to 2):
-            Number of labels for sequence classification.
 
     Examples:
 
@@ -159,7 +157,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         local_attention=128,
         global_attn_every_n_layers=3,
         local_rope_theta=None,
-        num_labels=2,
         **kwargs,
     ):
         super().__init__(
@@ -197,7 +194,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         self.local_attention = local_attention
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.local_rope_theta = local_rope_theta
-        self.num_labels = num_labels
 
     def to_dict(self):
         output = super().to_dict()
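Side note on the `num_labels` removals above: `PretrainedConfig` already consumes `num_labels` from `**kwargs` and exposes it as a property backed by `id2label`, so the subclass does not need its own parameter or attribute. A minimal sketch of that base-class behavior (illustration only, not part of this diff):

    from transformers import PretrainedConfig

    # num_labels passed through **kwargs is picked up by the base class,
    # which builds default id2label/label2id mappings from it.
    cfg = PretrainedConfig(num_labels=2)
    print(cfg.num_labels)  # 2
    print(cfg.id2label)    # {0: 'LABEL_0', 1: 'LABEL_1'}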
@@ -213,7 +213,8 @@ class ModernBertDecoderAttention(nn.Module):
             else:
                 # For initial forward pass, start from 0
                 position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
         # Apply rotary embeddings
         if past_key_value is not None:
@@ -977,7 +978,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1035,7 +1039,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        hidden_states = self.drop(self.head(hidden_states))
+        logits = self.classifier(hidden_states)
 
         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
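For context on the head rework above: the single scoring layer is replaced by the head -> dropout -> classifier composition that ModernBert's encoder-side classification models use. A rough, self-contained sketch of the resulting data flow, with a plain Linear+GELU block standing in for `ModernBertPredictionHead` and made-up sizes (nothing here is the actual implementation):

    import torch
    import torch.nn as nn

    hidden_size, num_labels, batch, seq = 64, 2, 3, 10

    head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU())  # stand-in for ModernBertPredictionHead
    drop = nn.Dropout(0.1)                                                # config.classifier_dropout
    classifier = nn.Linear(hidden_size, num_labels, bias=True)            # config.classifier_bias

    hidden_states = torch.randn(batch, seq, hidden_size)  # transformer_outputs[0]
    logits = classifier(drop(head(hidden_states)))         # same composition as the new forward
    print(logits.shape)  # torch.Size([3, 10, 2])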
@@ -123,8 +123,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
             Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
         local_rope_theta (`float`, *optional*):
             The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
-        num_labels (`int`, *optional*, defaults to 2):
-            Number of labels for sequence classification.
 
     Examples:
 
@@ -178,7 +176,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         local_attention=128,
         global_attn_every_n_layers=3,
         local_rope_theta=None,
-        num_labels=2,
         **kwargs,
     ):
         super().__init__(
@@ -216,7 +213,6 @@ class ModernBertDecoderConfig(PretrainedConfig):
         self.local_attention = local_attention
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.local_rope_theta = local_rope_theta
-        self.num_labels = num_labels
 
     def to_dict(self):
         output = super().to_dict()
@@ -443,7 +439,7 @@ class ModernBertDecoderAttention(nn.Module):
         qkv = self.Wqkv(hidden_states)
         qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
 
-        # Create position_ids if None
+        # Create position_ids that respect padding tokens if not provided
         if position_ids is None:
             device = hidden_states.device
             if past_key_value is not None:
@@ -451,9 +447,18 @@ class ModernBertDecoderAttention(nn.Module):
                 cache_length = past_key_value[0].shape[-2]
                 position_ids = torch.arange(cache_length, cache_length + seq_len, dtype=torch.long, device=device)
             else:
-                # For initial forward pass, start from 0
-                position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+                # For initial forward pass, create position_ids that respect padding
+                if attention_mask is not None:
+                    # Create cumulative sum of attention_mask to get proper positions
+                    # This ensures padding tokens don't increment position
+                    position_ids = attention_mask.long().cumsum(-1) - 1
+                    position_ids.masked_fill_(attention_mask == 0, 0)
+                else:
+                    # Fallback: sequential positions
+                    position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
+                    position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
         # Apply rotary embeddings
         if past_key_value is not None:
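The cumulative-sum construction added above is a common way to derive positions that skip padding (several decoder models in the library build `position_ids` the same way when preparing generation inputs). A standalone illustration:

    import torch

    # attention_mask: 1 for real tokens, 0 for (left) padding
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    # cumsum counts only real tokens, so padding does not advance the position
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)

    print(position_ids)
    # tensor([[0, 0, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])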
@@ -1153,7 +1158,10 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1211,7 +1219,8 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        hidden_states = self.drop(self.head(hidden_states))
+        logits = self.classifier(hidden_states)
 
         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
@@ -22,6 +22,7 @@ from transformers.testing_utils import (
 )
 
 from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+from ...test_modeling_common import _config_zero_init
 
 
 if is_torch_available():
@@ -170,3 +171,19 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
         # Check that loss is computed
         self.assertIsNotNone(outputs_with_loss.loss)
         self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
+
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                # The classifier.weight from ModernBertDecoderForSequenceClassification and ModernBertDecoderForCausalLM
+                # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
+                if param.requires_grad and not (name == "classifier.weight" or name == "head.weight"):
+                    self.assertIn(
+                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        [0.0, 1.0],
+                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                    )
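About the rounding check in the new test: `_config_zero_init` forces `initializer_range`-style config values down to ~1e-10, so weights drawn from the config-controlled initializer have a mean that rounds to 0.0, while ones-initialized parameters (e.g. LayerNorm weights) give 1.0. A toy version of just the numeric predicate, using a hypothetical helper name `looks_initialized`:

    import torch

    def looks_initialized(param: torch.Tensor) -> bool:
        # Round the mean to 1e-9 resolution; zero-init weights give 0.0,
        # ones-init weights (e.g. LayerNorm) give 1.0.
        return ((param.data.mean() * 1e9).round() / 1e9).item() in [0.0, 1.0]

    print(looks_initialized(torch.zeros(10, 10)))      # True
    print(looks_initialized(torch.ones(10)))           # True
    print(looks_initialized(torch.full((10,), 0.02)))  # False -> would fail the test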
@@ -2974,6 +2974,8 @@ class ModelTesterMixin:
             "ModernBertForTokenClassification",
             "TimmWrapperForImageClassification",
             "ModernBertForQuestionAnswering",
+            "ModernBertDecoderForSequenceClassification",
+            "ModernBertDecoderForCausalLM",
         ]
         special_param_names = [
             r"^bit\.",
@@ -279,6 +279,7 @@ SPECIAL_CASES_TO_ALLOW = {
         "max_position_embeddings",
         "mlp_bias",
         "mlp_dropout",
+        "classifier_activation",
     ],
 }
 