mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add TFDebertaV2ForMultipleChoice (#25932)
* Add TFDebertaV2ForMultipleChoice
* Import newer model in main init
* Fix import issues
* Fix copies
* Add doc
* Fix tests
* Fix copies
* Fix docstring
This commit is contained in:
parent
da1af21dbb
commit
1110b565d6
@ -152,3 +152,8 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
|
||||
|
||||
[[autodoc]] TFDebertaV2ForQuestionAnswering
|
||||
- call
|
||||
|
||||
## TFDebertaV2ForMultipleChoice
|
||||
|
||||
[[autodoc]] TFDebertaV2ForMultipleChoice
|
||||
- call
|
||||
|
@ -3360,6 +3360,7 @@ else:
|
||||
[
|
||||
"TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFDebertaV2ForMaskedLM",
|
||||
"TFDebertaV2ForMultipleChoice",
|
||||
"TFDebertaV2ForQuestionAnswering",
|
||||
"TFDebertaV2ForSequenceClassification",
|
||||
"TFDebertaV2ForTokenClassification",
|
||||
@ -6969,6 +6970,7 @@ if TYPE_CHECKING:
|
||||
from .models.deberta_v2 import (
|
||||
TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDebertaV2ForMaskedLM,
|
||||
TFDebertaV2ForMultipleChoice,
|
||||
TFDebertaV2ForQuestionAnswering,
|
||||
TFDebertaV2ForSequenceClassification,
|
||||
TFDebertaV2ForTokenClassification,
|
||||
|
@ -409,6 +409,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
||||
("bert", "TFBertForMultipleChoice"),
|
||||
("camembert", "TFCamembertForMultipleChoice"),
|
||||
("convbert", "TFConvBertForMultipleChoice"),
|
||||
("deberta-v2", "TFDebertaV2ForMultipleChoice"),
|
||||
("distilbert", "TFDistilBertForMultipleChoice"),
|
||||
("electra", "TFElectraForMultipleChoice"),
|
||||
("flaubert", "TFFlaubertForMultipleChoice"),
|
||||
|
@ -46,6 +46,7 @@ else:
|
||||
"TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFDebertaV2ForMaskedLM",
|
||||
"TFDebertaV2ForQuestionAnswering",
|
||||
"TFDebertaV2ForMultipleChoice",
|
||||
"TFDebertaV2ForSequenceClassification",
|
||||
"TFDebertaV2ForTokenClassification",
|
||||
"TFDebertaV2Model",
|
||||
@ -95,6 +96,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_tf_deberta_v2 import (
|
||||
TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDebertaV2ForMaskedLM,
|
||||
TFDebertaV2ForMultipleChoice,
|
||||
TFDebertaV2ForQuestionAnswering,
|
||||
TFDebertaV2ForSequenceClassification,
|
||||
TFDebertaV2ForTokenClassification,
|
||||
|
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 DeBERTa-v2 model."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
@ -26,6 +25,7 @@ from ...activations_tf import get_tf_activation
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFMaskedLMOutput,
|
||||
TFMultipleChoiceModelOutput,
|
||||
TFQuestionAnsweringModelOutput,
|
||||
TFSequenceClassifierOutput,
|
||||
TFTokenClassifierOutput,
|
||||
@ -33,6 +33,7 @@ from ...modeling_tf_outputs import (
|
||||
from ...modeling_tf_utils import (
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFModelInputType,
|
||||
TFMultipleChoiceLoss,
|
||||
TFPreTrainedModel,
|
||||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
@ -47,7 +48,6 @@ from .configuration_deberta_v2 import DebertaV2Config
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "DebertaV2Config"
|
||||
_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge"
|
||||
|
||||
@ -1529,3 +1529,102 @@ class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsw
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    # _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
    # _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Shared DeBERTa-v2 encoder; each flattened choice goes through it independently.
        self.deberta = TFDebertaV2MainLayer(config, name="deberta")
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.pooler = TFDebertaV2ContextPooler(config, name="pooler")
        # One scalar score per choice; scores are regrouped per example before the softmax loss.
        self.classifier = tf.keras.layers.Dense(
            units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
        """
        # Derive (num_choices, seq_length) from whichever input representation was provided.
        # NOTE(review): as in other TF multiple-choice heads, exactly one of input_ids /
        # inputs_embeds is expected to be non-None; both None would fail here.
        reference = input_ids if input_ids is not None else inputs_embeds
        num_choices = shape_list(reference)[1]
        seq_length = shape_list(reference)[2]

        def _merge_choice_dim(tensor):
            # Collapse (batch, num_choices, seq) -> (batch * num_choices, seq) so the
            # encoder sees each choice as an independent sequence.
            return tf.reshape(tensor=tensor, shape=(-1, seq_length)) if tensor is not None else None

        flat_input_ids = _merge_choice_dim(input_ids)
        flat_attention_mask = _merge_choice_dim(attention_mask)
        flat_token_type_ids = _merge_choice_dim(token_type_ids)
        flat_position_ids = _merge_choice_dim(position_ids)
        # Embeddings carry an extra hidden dimension, so they are flattened separately.
        flat_inputs_embeds = (
            tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        outputs = self.deberta(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Pool, regularize, and score each flattened choice, then regroup per example.
        pooled_output = self.pooler(outputs[0], training=training)
        logits = self.classifier(self.dropout(pooled_output, training=training))
        reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
@ -974,6 +974,13 @@ class TFDebertaV2ForMaskedLM(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFDebertaV2ForMultipleChoice(metaclass=DummyObject):
    # Import-time placeholder: raises a helpful error telling the user to install
    # TensorFlow when the real model class is unavailable.
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFDebertaV2ForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
|
@ -31,6 +31,7 @@ if is_tf_available():
|
||||
|
||||
from transformers import (
|
||||
TFDebertaV2ForMaskedLM,
|
||||
TFDebertaV2ForMultipleChoice,
|
||||
TFDebertaV2ForQuestionAnswering,
|
||||
TFDebertaV2ForSequenceClassification,
|
||||
TFDebertaV2ForTokenClassification,
|
||||
@ -196,6 +197,22 @@ class TFDebertaV2ModelTester:
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        # Instantiate the multiple-choice head and verify that the logits come out
        # shaped (batch_size, num_choices).
        config.num_choices = self.num_choices
        model = TFDebertaV2ForMultipleChoice(config=config)
        # Tile the 2D (batch, seq) fixtures along a new choices axis to build the
        # 3D (batch, num_choices, seq) inputs the multiple-choice model expects.
        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
        inputs = {
            "input_ids": multiple_choice_inputs_ids,
            "attention_mask": multiple_choice_input_mask,
            "token_type_ids": multiple_choice_token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@ -218,6 +235,7 @@ class TFDebertaModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
TFDebertaV2Model,
|
||||
TFDebertaV2ForMaskedLM,
|
||||
TFDebertaV2ForQuestionAnswering,
|
||||
TFDebertaV2ForMultipleChoice,
|
||||
TFDebertaV2ForSequenceClassification,
|
||||
TFDebertaV2ForTokenClassification,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user