Mistral-related models for QnA (#34045)

* mistral qna start

* mixtral qna

* oops

* qwen2 qna

* qwen2moe qna

* add missing input embed methods

* add `Copied from` to all methods; can't copy directly from llama due to the prefix

* make top level copied from
Anton Vlasjuk 2024-10-14 08:53:32 +02:00 committed by GitHub
parent 37ea04013b
commit 7434c0ed21
19 changed files with 507 additions and 4 deletions
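
Taken together, the diff wires four new `*ForQuestionAnswering` heads into the library. A minimal usage sketch of the end result; `my-org/mistral-7b-squad` is a hypothetical fine-tuned checkpoint used for illustration, not something this PR ships:

```python
import torch
from transformers import AutoTokenizer, MistralForQuestionAnswering

# Hypothetical checkpoint name; this PR adds the head, not a fine-tuned QA model.
ckpt = "my-org/mistral-7b-squad"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = MistralForQuestionAnswering.from_pretrained(ckpt)

question = "Who wrote the patch?"
context = "The question answering heads were contributed by Anton Vlasjuk."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Greedy span decoding: most likely start and end token positions.
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```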

docs/source/en/model_doc/mistral.md

@@ -208,6 +208,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to help you get started
[[autodoc]] MistralForTokenClassification
- forward
## MistralForQuestionAnswering
[[autodoc]] MistralForQuestionAnswering
- forward
## FlaxMistralModel
[[autodoc]] FlaxMistralModel

docs/source/en/model_doc/mixtral.md

@@ -209,3 +209,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to help you get started
[[autodoc]] MixtralForTokenClassification
- forward
## MixtralForQuestionAnswering
[[autodoc]] MixtralForQuestionAnswering
- forward

docs/source/en/model_doc/qwen2.md

@@ -85,3 +85,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Instruct` for the inference
[[autodoc]] Qwen2ForTokenClassification
- forward
## Qwen2ForQuestionAnswering
[[autodoc]] Qwen2ForQuestionAnswering
- forward

docs/source/en/model_doc/qwen2_moe.md

@@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inference
[[autodoc]] Qwen2MoeForTokenClassification
- forward
## Qwen2MoeForQuestionAnswering
[[autodoc]] Qwen2MoeForQuestionAnswering
- forward

src/transformers/__init__.py

@@ -2709,6 +2709,7 @@ else:
_import_structure["models.mistral"].extend(
[
"MistralForCausalLM",
"MistralForQuestionAnswering",
"MistralForSequenceClassification",
"MistralForTokenClassification",
"MistralModel",
@@ -2718,6 +2719,7 @@ else:
_import_structure["models.mixtral"].extend(
[
"MixtralForCausalLM",
"MixtralForQuestionAnswering",
"MixtralForSequenceClassification",
"MixtralForTokenClassification",
"MixtralModel",
@@ -3094,6 +3096,7 @@ else:
_import_structure["models.qwen2"].extend(
[
"Qwen2ForCausalLM",
"Qwen2ForQuestionAnswering",
"Qwen2ForSequenceClassification",
"Qwen2ForTokenClassification",
"Qwen2Model",
@@ -3110,6 +3113,7 @@ else:
_import_structure["models.qwen2_moe"].extend(
[
"Qwen2MoeForCausalLM",
"Qwen2MoeForQuestionAnswering",
"Qwen2MoeForSequenceClassification",
"Qwen2MoeForTokenClassification",
"Qwen2MoeModel",
@@ -7323,6 +7327,7 @@ if TYPE_CHECKING:
)
from .models.mistral import (
MistralForCausalLM,
MistralForQuestionAnswering,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
@@ -7330,6 +7335,7 @@ if TYPE_CHECKING:
)
from .models.mixtral import (
MixtralForCausalLM,
MixtralForQuestionAnswering,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,
@@ -7625,6 +7631,7 @@ if TYPE_CHECKING:
)
from .models.qwen2 import (
Qwen2ForCausalLM,
Qwen2ForQuestionAnswering,
Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2Model,
@@ -7637,6 +7644,7 @@ if TYPE_CHECKING:
)
from .models.qwen2_moe import (
Qwen2MoeForCausalLM,
Qwen2MoeForQuestionAnswering,
Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeModel,

src/transformers/models/auto/modeling_auto.py

@@ -1046,6 +1046,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("mbart", "MBartForQuestionAnswering"),
("mega", "MegaForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mistral", "MistralForQuestionAnswering"),
("mixtral", "MixtralForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("mpt", "MptForQuestionAnswering"),
@@ -1057,6 +1059,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("nystromformer", "NystromformerForQuestionAnswering"),
("opt", "OPTForQuestionAnswering"),
("qdqbert", "QDQBertForQuestionAnswering"),
("qwen2", "Qwen2ForQuestionAnswering"),
("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),

src/transformers/models/mistral/__init__.py

@@ -35,6 +35,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modeling_mistral"] = [
"MistralForCausalLM",
"MistralForQuestionAnswering",
"MistralModel",
"MistralPreTrainedModel",
"MistralForSequenceClassification",
@@ -78,6 +79,7 @@ if TYPE_CHECKING:
else:
from .modeling_mistral import (
MistralForCausalLM,
MistralForQuestionAnswering,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,

src/transformers/models/mistral/modeling_mistral.py

@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
@@ -1408,3 +1409,103 @@ class MistralForTokenClassification(MistralPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Mistral Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MISTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model
class MistralForQuestionAnswering(MistralPreTrainedModel):
base_model_prefix = "model"
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mistral
def __init__(self, config):
super().__init__(config)
self.model = MistralModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the labels may carry an extra dimension; squeeze it and move them to the logits' device
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
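
The head is nothing more than a `hidden_size -> 2` projection whose two output channels become the start and end logits. A self-contained sketch of the split/squeeze shape mechanics used in `forward` above:

```python
import torch
from torch import nn

batch, seq_len, hidden = 2, 10, 32
hidden_states = torch.randn(batch, seq_len, hidden)   # stand-in for backbone output
qa_outputs = nn.Linear(hidden, 2)

logits = qa_outputs(hidden_states)                    # (batch, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)    # two (batch, seq_len, 1) tensors
start_logits = start_logits.squeeze(-1).contiguous()  # (batch, seq_len)
end_logits = end_logits.squeeze(-1).contiguous()      # (batch, seq_len)
print(start_logits.shape, end_logits.shape)
```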

src/transformers/models/mixtral/__init__.py

@@ -33,6 +33,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modeling_mixtral"] = [
"MixtralForCausalLM",
"MixtralForQuestionAnswering",
"MixtralModel",
"MixtralPreTrainedModel",
"MixtralForSequenceClassification",
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
else:
from .modeling_mixtral import (
MixtralForCausalLM,
MixtralForQuestionAnswering,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,

src/transformers/models/mixtral/modeling_mixtral.py

@@ -35,6 +35,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
@@ -1644,3 +1645,103 @@ class MixtralForTokenClassification(MixtralPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Mixtral Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Mixtral, MISTRAL->MIXTRAL
class MixtralForQuestionAnswering(MixtralPreTrainedModel):
base_model_prefix = "model"
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mixtral
def __init__(self, config):
super().__init__(config)
self.model = MixtralModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the labels may carry an extra dimension; squeeze it and move them to the logits' device
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
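
Worth spelling out why the clamp is safe: labels are clamped to `ignored_index == seq_len`, one past the last valid class, and `CrossEntropyLoss(ignore_index=ignored_index)` then drops exactly those terms. A small standalone illustration:

```python
import torch
from torch.nn import CrossEntropyLoss

start_logits = torch.randn(3, 8)             # batch of 3, sequence length 8
start_positions = torch.tensor([2, 8, 100])  # valid, borderline, far out of range

ignored_index = start_logits.size(1)                        # 8
start_positions = start_positions.clamp(0, ignored_index)   # -> tensor([2, 8, 8])

# Targets equal to ignore_index contribute nothing to the mean loss.
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
print(loss_fct(start_logits, start_positions))
```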

src/transformers/models/qwen2/__init__.py

@@ -42,6 +42,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modeling_qwen2"] = [
"Qwen2ForCausalLM",
"Qwen2ForQuestionAnswering",
"Qwen2Model",
"Qwen2PreTrainedModel",
"Qwen2ForSequenceClassification",
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
else:
from .modeling_qwen2 import (
Qwen2ForCausalLM,
Qwen2ForQuestionAnswering,
Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2Model,

src/transformers/models/qwen2/modeling_qwen2.py

@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
@@ -1508,3 +1509,103 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
QWEN2_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2
class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
base_model_prefix = "model"
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2
def __init__(self, config):
super().__init__(config)
self.model = Qwen2Model(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the labels may carry an extra dimension; squeeze it and move them to the logits' device
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
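
For fine-tuning, the labels go straight into `forward` and the averaged start/end cross-entropy comes back as `outputs.loss`. A hedged end-to-end sketch with a randomly initialized, deliberately tiny config (illustrative sizes only, no real checkpoint):

```python
import torch
from transformers import Qwen2Config, Qwen2ForQuestionAnswering

config = Qwen2Config(
    vocab_size=1000, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = Qwen2ForQuestionAnswering(config)  # random weights, just to show the API

input_ids = torch.randint(0, config.vocab_size, (2, 12))
outputs = model(
    input_ids,
    start_positions=torch.tensor([1, 3]),
    end_positions=torch.tensor([4, 7]),
)
print(outputs.loss)      # (start_loss + end_loss) / 2
outputs.loss.backward()  # gradients reach both the backbone and qa_outputs
```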

src/transformers/models/qwen2_moe/__init__.py

@@ -33,6 +33,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modeling_qwen2_moe"] = [
"Qwen2MoeForCausalLM",
"Qwen2MoeForQuestionAnswering",
"Qwen2MoeModel",
"Qwen2MoePreTrainedModel",
"Qwen2MoeForSequenceClassification",
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
else:
from .modeling_qwen2_moe import (
Qwen2MoeForCausalLM,
Qwen2MoeForQuestionAnswering,
Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeModel,

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

@@ -35,6 +35,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
@@ -1713,3 +1714,103 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Qwen2MoE Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
QWEN2MOE_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE
class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel):
base_model_prefix = "model"
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2Moe
def __init__(self, config):
super().__init__(config)
self.model = Qwen2MoeModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the labels may carry an extra dimension; squeeze it and move them to the logits' device
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
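
One more detail the copied `forward` carries over: with a sharded backbone, the labels can arrive on a different device (or with a stray trailing dimension) than the logits, hence the `squeeze(-1).to(...)` handling. In isolation:

```python
import torch

start_logits = torch.randn(2, 8)            # imagine this on the last pipeline stage
start_positions = torch.tensor([[1], [5]])  # (batch, 1), as some collators emit

# Mirror of the handling in forward: drop the extra dim, co-locate with logits.
if start_positions.dim() > 1:
    start_positions = start_positions.squeeze(-1).to(start_logits.device)
print(start_positions.shape)  # torch.Size([2])
```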

src/transformers/utils/dummy_pt_objects.py

@@ -5920,6 +5920,13 @@ class MistralForCausalLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MistralForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MistralForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
@@ -5955,6 +5962,13 @@ class MixtralForCausalLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MixtralForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MixtralForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
@@ -7406,6 +7420,13 @@ class Qwen2ForCausalLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Qwen2ForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Qwen2ForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
@@ -7462,6 +7483,13 @@ class Qwen2MoeForCausalLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Qwen2MoeForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Qwen2MoeForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

tests/models/mistral/test_modeling_mistral.py

@@ -47,6 +47,7 @@ if is_torch_available():
from transformers import (
MistralForCausalLM,
MistralForQuestionAnswering,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
@@ -291,7 +292,13 @@ class MistralModelTester:
@require_torch
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
(
MistralModel,
MistralForCausalLM,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralForQuestionAnswering,
)
if is_torch_available()
else ()
)
@@ -303,6 +310,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
"token-classification": MistralForTokenClassification,
"text-generation": MistralForCausalLM,
"zero-shot": MistralForSequenceClassification,
"question-answering": MistralForQuestionAnswering,
}
if is_torch_available()
else {}

tests/models/mixtral/test_modeling_mixtral.py

@@ -41,6 +41,7 @@ if is_torch_available():
from transformers import (
MixtralForCausalLM,
MixtralForQuestionAnswering,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,
@@ -291,7 +292,13 @@ class MixtralModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
(
MixtralModel,
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralForQuestionAnswering,
)
if is_torch_available()
else ()
)
@@ -303,6 +310,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
"token-classification": MixtralForTokenClassification,
"text-generation": MixtralForCausalLM,
"zero-shot": MixtralForSequenceClassification,
"question-answering": MixtralForQuestionAnswering,
}
if is_torch_available()
else {}

tests/models/qwen2/test_modeling_qwen2.py

@@ -43,6 +43,7 @@ if is_torch_available():
from transformers import (
Qwen2ForCausalLM,
Qwen2ForQuestionAnswering,
Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2Model,
@@ -300,7 +301,13 @@ class Qwen2ModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
(
Qwen2Model,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2ForQuestionAnswering,
)
if is_torch_available()
else ()
)
@@ -312,6 +319,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
"token-classification": Qwen2ForTokenClassification,
"text-generation": Qwen2ForCausalLM,
"zero-shot": Qwen2ForSequenceClassification,
"question-answering": Qwen2ForQuestionAnswering,
}
if is_torch_available()
else {}

tests/models/qwen2_moe/test_modeling_qwen2_moe.py

@@ -43,6 +43,7 @@ if is_torch_available():
from transformers import (
Qwen2MoeForCausalLM,
Qwen2MoeForQuestionAnswering,
Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeModel,
@@ -327,7 +328,13 @@ class Qwen2MoeModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
(
Qwen2MoeModel,
Qwen2MoeForCausalLM,
Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeForQuestionAnswering,
)
if is_torch_available()
else ()
)
@@ -339,6 +346,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
"token-classification": Qwen2MoeForTokenClassification,
"text-generation": Qwen2MoeForCausalLM,
"zero-shot": Qwen2MoeForSequenceClassification,
"question-answering": Qwen2MoeForQuestionAnswering,
}
if is_torch_available()
else {}
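
Since the test files also register the heads in the `pipeline_model_mapping` dictionaries above, the high-level `pipeline` API picks them up too. A final sketch; `my-org/qwen2-qa` is a placeholder name for some fine-tuned checkpoint, not a model this PR publishes:

```python
from transformers import pipeline

# Placeholder checkpoint; substitute any Qwen2 model fine-tuned for extractive QA.
qa = pipeline("question-answering", model="my-org/qwen2-qa")
result = qa(
    question="Which models gained QA heads?",
    context="Mistral, Mixtral, Qwen2 and Qwen2MoE gained question answering heads.",
)
print(result["answer"], result["score"])
```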