Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add type hints for TF MPNet models (#19089)
* Added type hints for TFMPNetModel
* Added type hints for TFMPNetForMaskedLM
* Added type hints for TFMPNetForSequenceClassification
* Added type hints for TFMPNetForMultipleChoice
* Added type hints for TFMPNetForTokenClassification
* Added type hints for TFMPNetForQuestionAnswering
This commit is contained in:
parent 1bbad7a2da
commit fe5e7cea4a
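As a quick illustration (not part of the diff below), here is a minimal usage sketch of the newly annotated TFMPNetModel.call, assuming the public microsoft/mpnet-base checkpoint: inputs may be NumPy arrays or TF tensors, matching the Optional[Union[np.ndarray, tf.Tensor]] hints, and the return value follows the new Union[TFBaseModelOutput, Tuple[tf.Tensor]] annotation.

# Sketch only: exercising the newly annotated TFMPNetModel.call signature.
# "microsoft/mpnet-base" is the usual public checkpoint, assumed here for illustration.
from transformers import MPNetTokenizer, TFMPNetModel

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = TFMPNetModel.from_pretrained("microsoft/mpnet-base")

# Inputs may be TF tensors...
enc = tokenizer("Hello MPNet", return_tensors="tf")
out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

# ...or plain NumPy arrays, as the Union[np.ndarray, tf.Tensor] hints allow.
enc_np = tokenizer("Hello MPNet", return_tensors="np")
out = model(input_ids=enc_np["input_ids"], attention_mask=enc_np["attention_mask"])

print(out.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)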
@@ -18,7 +18,9 @@
 import math
 import warnings
+from typing import Optional, Tuple, Union
 
+import numpy as np
 import tensorflow as tf
 
 from ...activations_tf import get_tf_activation
@@ -33,6 +35,7 @@ from ...modeling_tf_outputs import (
 )
 from ...modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
+    TFModelInputType,
     TFMultipleChoiceLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
@@ -681,16 +684,16 @@ class TFMPNetModel(TFMPNetPreTrainedModel):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
-    ):
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.array, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.array, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.array, tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
         outputs = self.mpnet(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -796,17 +799,17 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
-    ):
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: bool = False,
+    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
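A hedged sketch of driving the annotated masked-LM head (same checkpoint assumption as above); passing labels makes the returned TFMaskedLMOutput carry a loss:

# Sketch only: masked-LM usage with the annotated signature. In practice labels hold
# the original (unmasked) token ids, with -100 at positions to ignore, per the docstring above.
from transformers import MPNetTokenizer, TFMPNetForMaskedLM

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = TFMPNetForMaskedLM.from_pretrained("microsoft/mpnet-base")

enc = tokenizer(f"The capital of France is {tokenizer.mask_token}.", return_tensors="tf")
out = model(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    labels=enc["input_ids"],  # tf.Tensor labels, matching the new Optional[tf.Tensor] hint
)
print(out.loss, out.logits.shape)  # logits: (batch_size, sequence_length, vocab_size)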
@@ -901,17 +904,17 @@ class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassificationLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
-    ):
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.array, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.array, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.array, tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: bool = False,
+    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
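A similar sketch for the sequence-classification head; num_labels=2 and the label value are illustrative only:

# Sketch only: sequence classification with the annotated signature.
import tensorflow as tf
from transformers import MPNetTokenizer, TFMPNetForSequenceClassification

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = TFMPNetForSequenceClassification.from_pretrained("microsoft/mpnet-base", num_labels=2)

enc = tokenizer("A short example sentence.", return_tensors="tf")
out = model(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    labels=tf.constant([1]),  # shape (batch_size,), per the docstring above
)
print(out.loss, out.logits.shape)  # logits: (batch_size, num_labels)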
@@ -991,17 +994,17 @@ class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
-    ):
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: bool = False,
+    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
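A sketch for the multiple-choice head; the prompt and choices are invented for illustration, and the extra reshape to (batch_size, num_choices, sequence_length) is the usual pattern for these heads:

# Sketch only: multiple choice with the annotated signature.
import tensorflow as tf
from transformers import MPNetTokenizer, TFMPNetForMultipleChoice

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = TFMPNetForMultipleChoice.from_pretrained("microsoft/mpnet-base")

prompt = "The sky is"
choices = ["blue.", "made of cheese."]
enc = tokenizer([prompt, prompt], choices, return_tensors="tf", padding=True)

# Add a batch dimension: (batch_size, num_choices, sequence_length).
out = model(
    input_ids=tf.expand_dims(enc["input_ids"], 0),
    attention_mask=tf.expand_dims(enc["attention_mask"], 0),
    labels=tf.constant([0]),  # index of the correct choice
)
print(out.loss, out.logits.shape)  # logits: (batch_size, num_choices)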
@@ -1102,17 +1105,17 @@ class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificationLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
-    ):
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: bool = False,
+    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
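A sketch for the token-classification head; the all-zero labels are placeholders, one label per token in [0, config.num_labels - 1] as the docstring above notes:

# Sketch only: token classification with the annotated signature.
import tensorflow as tf
from transformers import MPNetTokenizer, TFMPNetForTokenClassification

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = TFMPNetForTokenClassification.from_pretrained("microsoft/mpnet-base", num_labels=3)

enc = tokenizer("Tag these tokens.", return_tensors="tf")
labels = tf.zeros_like(enc["input_ids"])  # (batch_size, sequence_length), all label 0 for the demo
out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], labels=labels)
print(out.loss, out.logits.shape)  # logits: (batch_size, sequence_length, num_labels)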
@@ -1184,19 +1187,19 @@ class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        start_positions=None,
-        end_positions=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.array, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.array, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.array, tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        start_positions: Optional[tf.Tensor] = None,
+        end_positions: Optional[tf.Tensor] = None,
+        training: bool = False,
         **kwargs,
-    ):
+    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
         r"""
         start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
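Finally, a sketch for the question-answering head; the question, context, and start/end indices are arbitrary illustrations (chosen to be in range for this input), and passing both position tensors yields a loss:

# Sketch only: question answering with the annotated signature.
import tensorflow as tf
from transformers import MPNetTokenizer, TFMPNetForQuestionAnswering

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = TFMPNetForQuestionAnswering.from_pretrained("microsoft/mpnet-base")

enc = tokenizer(
    "Who wrote MPNet?",
    "MPNet was proposed by researchers at Microsoft.",
    return_tensors="tf",
)
out = model(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    start_positions=tf.constant([7]),  # token indices, shape (batch_size,)
    end_positions=tf.constant([9]),
)
print(out.loss, out.start_logits.shape, out.end_logits.shape)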