Support QuestionAnswering module for ModernBert-based models (#35566)

* push ModernBertForQuestionAnswering

* update ModernBertForQuestionAnswering

* update __init__ loading

* set imports for ModernBertForQuestionAnswering

* update ModernBertForQuestionAnswering

* remove debugging logs

* update init_weights method

* remove custom initialization for ModernBertForQuestionAnswering

* apply make fix-copies

* apply make style

* apply make fix-copies

* append ModernBertForQuestionAnswering to the pipeline supported models

* remove unused file

* remove invalid autoload value

* update en/model_doc/modernbert.md

* apply make fixup command

* make fixup

* Update dummies

* update usage tips for ModernBertForQuestionAnswering

* update usage tips for ModernBertForQuestionAnswering

* add init

* add lint

* add consistency

* update init test

* change text to trigger stuck text

* use self.loss_function instead of custom loss

By @Cyrilvallez

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* Update modeling_modernbert.py

make comparable commit to even it out

* Match whitespace

* whitespace

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Orion Weller <wellerorion@gmail.com>
Co-authored-by: Orion Weller <31665361+orionw@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Abu Bakr Soliman, 2025-03-26 22:24:18 +02:00, committed by GitHub
parent 5b08db8844
commit 49b5ab6a27
8 changed files with 225 additions and 4 deletions

docs/source/en/model_doc/modernbert.md

@@ -60,6 +60,9 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
- [Masked language modeling task guide](../tasks/masked_language_modeling)
<PipelineTag pipeline="question-answering"/>
- [`ModernBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [colab notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
## ModernBertConfig
@@ -88,5 +91,15 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] ModernBertForTokenClassification
- forward
## ModernBertForQuestionAnswering
[[autodoc]] ModernBertForQuestionAnswering
- forward
### Usage tips
The ModernBert model can be fine-tuned for question-answering tasks with the Hugging Face Transformers library, using its [official example script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa.py).
</pt>
</frameworkcontent>
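
As a quick sanity check of the documented workflow, here is a minimal sketch of querying the new head through the `question-answering` pipeline. It assumes the `answerdotai/ModernBERT-base` checkpoint; since the base model's QA head is randomly initialized, the predicted span is meaningless until the model is fine-tuned (e.g. with `run_qa.py`):

```python
from transformers import pipeline

# The QA head is randomly initialized on the base checkpoint; fine-tune
# first (e.g. with run_qa.py) to get meaningful spans.
qa = pipeline("question-answering", model="answerdotai/ModernBERT-base")
result = qa(
    question="Where is the Eiffel Tower?",
    context="The Eiffel Tower is located in Paris, France.",
)
print(result["answer"], result["score"])
```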

src/transformers/__init__.py

@@ -3047,6 +3047,7 @@ else:
_import_structure["models.modernbert"].extend(
[
"ModernBertForMaskedLM",
"ModernBertForQuestionAnswering",
"ModernBertForSequenceClassification",
"ModernBertForTokenClassification",
"ModernBertModel",
@@ -7967,6 +7968,7 @@ if TYPE_CHECKING:
)
from .models.modernbert import (
ModernBertForMaskedLM,
ModernBertForQuestionAnswering,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertModel,
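
Both halves of this change matter: the `_import_structure` entry backs the lazy top-level export, while the `TYPE_CHECKING` branch keeps static analyzers aware of it. A one-line check that the wiring works:

```python
# Resolved lazily on first access; no torch import happens at module load.
from transformers import ModernBertForQuestionAnswering
```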

src/transformers/models/auto/modeling_auto.py

@@ -1138,6 +1138,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("mistral", "MistralForQuestionAnswering"),
("mixtral", "MixtralForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("modernbert", "ModernBertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("mpt", "MptForQuestionAnswering"),
("mra", "MraForQuestionAnswering"),

src/transformers/models/modernbert/modeling_modernbert.py

@@ -30,7 +30,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
-from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_outputs import (
+    BaseModelOutput,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -650,7 +656,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
init_weight(module.dense, stds["out"])
elif isinstance(module, ModernBertForMaskedLM):
init_weight(module.decoder, stds["out"])
-elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
+elif isinstance(
+    module,
+    (ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
+):
init_weight(module.classifier, stds["final_out"])
@classmethod
@@ -1384,10 +1393,98 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
)
@add_start_docstrings(
"""
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MODERNBERT_START_DOCSTRING,
)
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
self.drop = torch.nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
batch_size: Optional[int] = None,
seq_len: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self._maybe_set_compile()
outputs = self.model(
input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=batch_size,
seq_len=seq_len,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = outputs[0]
last_hidden_state = self.head(last_hidden_state)
last_hidden_state = self.drop(last_hidden_state)
logits = self.classifier(last_hidden_state)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
loss = None
if start_positions is not None and end_positions is not None:
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return QuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"ModernBertModel",
"ModernBertPreTrainedModel",
"ModernBertForMaskedLM",
"ModernBertForSequenceClassification",
"ModernBertForTokenClassification",
"ModernBertForQuestionAnswering",
]
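
The forward pass returns independent start and end logits over the sequence. A sketch of decoding an answer span from them, using greedy argmax and skipping the start ≤ end filtering that real decoders apply (the checkpoint name is assumed; the base model's QA head is untrained):

```python
import torch
from transformers import AutoTokenizer, ModernBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertForQuestionAnswering.from_pretrained("answerdotai/ModernBERT-base")

# Question and context are packed into a single sequence pair.
inputs = tokenizer("Who wrote the report?", "The report was written by Ada.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Greedy span: the most likely start and end token positions.
start = outputs.start_logits[0].argmax().item()
end = outputs.end_logits[0].argmax().item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1], skip_special_tokens=True))
```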

src/transformers/models/modernbert/modular_modernbert.py

@@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
@@ -825,7 +826,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
init_weight(module.dense, stds["out"])
elif isinstance(module, ModernBertForMaskedLM):
init_weight(module.decoder, stds["out"])
-elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
+elif isinstance(
+    module,
+    (ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
+):
init_weight(module.classifier, stds["final_out"])
@classmethod
@@ -1487,6 +1491,93 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
)
@add_start_docstrings(
"""
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MODERNBERT_START_DOCSTRING,
)
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
self.drop = torch.nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
batch_size: Optional[int] = None,
seq_len: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self._maybe_set_compile()
outputs = self.model(
input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=batch_size,
seq_len=seq_len,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = outputs[0]
last_hidden_state = self.head(last_hidden_state)
last_hidden_state = self.drop(last_hidden_state)
logits = self.classifier(last_hidden_state)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
loss = None
if start_positions is not None and end_positions is not None:
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return QuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"ModernBertConfig",
"ModernBertModel",
@@ -1494,4 +1585,5 @@ __all__ = [
"ModernBertForMaskedLM",
"ModernBertForSequenceClassification",
"ModernBertForTokenClassification",
"ModernBertForQuestionAnswering",
]
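
Per the review change noted in the commit log, both copies of the class delegate to `self.loss_function` instead of a hand-rolled loss. A sketch of what that generic extractive-QA objective computes, assuming it follows the standard SQuAD-style formulation:

```python
import torch
from torch.nn import CrossEntropyLoss

def qa_loss(start_logits, end_logits, start_positions, end_positions):
    # Targets outside the sequence (e.g. answers lost to truncation) are
    # clamped to seq_len and masked out of the loss via ignore_index.
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)
    loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    # The span loss is the mean of the start- and end-position cross-entropies.
    return (start_loss + end_loss) / 2
```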

src/transformers/utils/dummy_pt_objects.py

@@ -6692,6 +6692,13 @@ class ModernBertForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class ModernBertForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ModernBertForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

tests/models/modernbert/test_modeling_modernbert.py

@@ -40,6 +40,7 @@ if is_torch_available():
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
ModernBertForMaskedLM,
ModernBertForQuestionAnswering,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertModel,
@@ -224,6 +225,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
ModernBertForMaskedLM,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertForQuestionAnswering,
)
if is_torch_available()
else ()
@@ -235,6 +237,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
"text-classification": ModernBertForSequenceClassification,
"token-classification": ModernBertForTokenClassification,
"zero-shot": ModernBertForSequenceClassification,
"question-answering": ModernBertForQuestionAnswering,
}
if is_torch_available()
else {}
@@ -289,7 +292,12 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
# are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
if param.requires_grad and not (
name == "classifier.weight"
-and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification]
+and model_class
+in [
+    ModernBertForSequenceClassification,
+    ModernBertForTokenClassification,
+    ModernBertForQuestionAnswering,
+]
):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
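
The widened exemption is needed because these classifier heads are initialized from the fixed `stds["final_out"]` value rather than `config.initializer_range`. A simplified sketch of what the surrounding check does (the helper name mirrors `_config_zero_init` from the common tests):

```python
import copy

def config_zero_init(config):
    # Shrink every init-scale attribute to ~0: parameters drawn from
    # initializer_range must then have near-zero mean, while heads using a
    # fixed std (like classifier.weight here) legitimately do not.
    no_init = copy.deepcopy(config)
    for key in no_init.__dict__:
        if "initializer" in key or key.endswith("_range") or key.endswith("_std"):
            setattr(no_init, key, 1e-10)
    return no_init
```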

tests/test_modeling_common.py

@@ -3086,6 +3086,7 @@ class ModelTesterMixin:
"ModernBertForSequenceClassification",
"ModernBertForTokenClassification",
"TimmWrapperForImageClassification",
"ModernBertForQuestionAnswering",
]
special_param_names = [
r"^bit\.",