Use the pooling head in TokenClassification

This commit is contained in:
Benjamin Warner 2024-12-18 11:10:27 -06:00
parent e057bc27ad
commit 99c38badd1
2 changed files with 4 additions and 4 deletions

View File

@ -1297,7 +1297,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
self.num_labels = config.num_labels
self.model = ModernBertModel(config)
self.drop = nn.Dropout(config.classifier_dropout)
self.head = ModernBertPoolingHead(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
@ -1348,7 +1348,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
)
last_hidden_state = outputs[0]
last_hidden_state = self.drop(last_hidden_state)
last_hidden_state = self.head(last_hidden_state, attention_mask, pool=False)
logits = self.classifier(last_hidden_state)
loss = None

View File

@ -1425,7 +1425,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
self.num_labels = config.num_labels
self.model = ModernBertModel(config)
self.drop = nn.Dropout(config.classifier_dropout)
self.head = ModernBertPoolingHead(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
@ -1476,7 +1476,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
)
last_hidden_state = outputs[0]
last_hidden_state = self.drop(last_hidden_state)
last_hidden_state = self.head(last_hidden_state, attention_mask, pool=False)
logits = self.classifier(last_hidden_state)
loss = None