mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Support QuestionAnswering Module for ModernBert based models. (#35566)
* push ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * update __init__ loading * set imports for ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * remove debugging logs * update init_weights method * remove custom initialization for ModernBertForQuestionAnswering * apply make fix-copies * apply make style * apply make fix-copies * append ModernBertForQuestionAnswering to the pipeline supported models * remove unused file * remove invalid autoload value * update en/model_doc/modernbert.md * apply make fixup command * make fixup * Update dummies * update usage tips for ModernBertForQuestionAnswering * update usage tips for ModernBertForQuestionAnswering * add init * add lint * add consistency * update init test * change text to trigger stuck text * use self.loss_function instead of custom loss By @Cyrilvallez Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * Update modeling_modernbert.py make comparable commit to even it out * Match whitespace * whitespace --------- Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Orion Weller <wellerorion@gmail.com> Co-authored-by: Orion Weller <31665361+orionw@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
parent
5b08db8844
commit
49b5ab6a27
@ -60,6 +60,9 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
|
||||
- [`ModernBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [colab notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
|
||||
|
||||
## ModernBertConfig
|
||||
|
||||
@ -88,5 +91,15 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] ModernBertForTokenClassification
|
||||
- forward
|
||||
|
||||
## ModernBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] ModernBertForQuestionAnswering
|
||||
- forward
|
||||
|
||||
### Usage tips
|
||||
|
||||
The ModernBert model can be fine-tuned using the HuggingFace Transformers library with its [official script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa.py) for question-answering tasks.
|
||||
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
@ -3047,6 +3047,7 @@ else:
|
||||
_import_structure["models.modernbert"].extend(
|
||||
[
|
||||
"ModernBertForMaskedLM",
|
||||
"ModernBertForQuestionAnswering",
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertModel",
|
||||
@ -7967,6 +7968,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.modernbert import (
|
||||
ModernBertForMaskedLM,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertModel,
|
||||
|
@ -1138,6 +1138,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("mistral", "MistralForQuestionAnswering"),
|
||||
("mixtral", "MixtralForQuestionAnswering"),
|
||||
("mobilebert", "MobileBertForQuestionAnswering"),
|
||||
("modernbert", "ModernBertForQuestionAnswering"),
|
||||
("mpnet", "MPNetForQuestionAnswering"),
|
||||
("mpt", "MptForQuestionAnswering"),
|
||||
("mra", "MraForQuestionAnswering"),
|
||||
|
@ -30,7 +30,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
@ -650,7 +656,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
init_weight(module.dense, stds["out"])
|
||||
elif isinstance(module, ModernBertForMaskedLM):
|
||||
init_weight(module.decoder, stds["out"])
|
||||
elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
|
||||
elif isinstance(
|
||||
module,
|
||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||
):
|
||||
init_weight(module.classifier, stds["final_out"])
|
||||
|
||||
@classmethod
|
||||
@ -1384,10 +1393,98 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
|
||||
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
MODERNBERT_START_DOCSTRING,
|
||||
)
|
||||
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.model = ModernBertModel(config)
|
||||
self.head = ModernBertPredictionHead(config)
|
||||
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
start_positions: Optional[torch.Tensor] = None,
|
||||
end_positions: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = outputs[0]
|
||||
|
||||
last_hidden_state = self.head(last_hidden_state)
|
||||
last_hidden_state = self.drop(last_hidden_state)
|
||||
logits = self.classifier(last_hidden_state)
|
||||
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModernBertModel",
|
||||
"ModernBertPreTrainedModel",
|
||||
"ModernBertForMaskedLM",
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertForQuestionAnswering",
|
||||
]
|
||||
|
@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
@ -825,7 +826,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
init_weight(module.dense, stds["out"])
|
||||
elif isinstance(module, ModernBertForMaskedLM):
|
||||
init_weight(module.decoder, stds["out"])
|
||||
elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
|
||||
elif isinstance(
|
||||
module,
|
||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||
):
|
||||
init_weight(module.classifier, stds["final_out"])
|
||||
|
||||
@classmethod
|
||||
@ -1487,6 +1491,93 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
|
||||
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
MODERNBERT_START_DOCSTRING,
|
||||
)
|
||||
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.model = ModernBertModel(config)
|
||||
self.head = ModernBertPredictionHead(config)
|
||||
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
start_positions: Optional[torch.Tensor] = None,
|
||||
end_positions: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = outputs[0]
|
||||
|
||||
last_hidden_state = self.head(last_hidden_state)
|
||||
last_hidden_state = self.drop(last_hidden_state)
|
||||
logits = self.classifier(last_hidden_state)
|
||||
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModernBertConfig",
|
||||
"ModernBertModel",
|
||||
@ -1494,4 +1585,5 @@ __all__ = [
|
||||
"ModernBertForMaskedLM",
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertForQuestionAnswering",
|
||||
]
|
||||
|
@ -6692,6 +6692,13 @@ class ModernBertForMaskedLM(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ModernBertForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ModernBertForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -40,6 +40,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
ModernBertForMaskedLM,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertModel,
|
||||
@ -224,6 +225,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
ModernBertForMaskedLM,
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@ -235,6 +237,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
"text-classification": ModernBertForSequenceClassification,
|
||||
"token-classification": ModernBertForTokenClassification,
|
||||
"zero-shot": ModernBertForSequenceClassification,
|
||||
"question-answering": ModernBertForQuestionAnswering,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@ -289,7 +292,12 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
# are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
|
||||
if param.requires_grad and not (
|
||||
name == "classifier.weight"
|
||||
and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification]
|
||||
and model_class
|
||||
in [
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
]
|
||||
):
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
|
@ -3086,6 +3086,7 @@ class ModelTesterMixin:
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"TimmWrapperForImageClassification",
|
||||
"ModernBertForQuestionAnswering",
|
||||
]
|
||||
special_param_names = [
|
||||
r"^bit\.",
|
||||
|
Loading…
Reference in New Issue
Block a user