mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Add type hints transfoxl (#16267)
* Add type hint for pt transfo_xl model * Add type hint for tf transfo_xl model
This commit is contained in:
parent
2afe9cd279
commit
460f36d352
@ -18,8 +18,9 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from ...file_utils import (
|
||||
@ -29,6 +30,7 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from ...modeling_tf_utils import (
|
||||
TFModelInputType,
|
||||
TFPreTrainedModel,
|
||||
TFSequenceClassificationLoss,
|
||||
get_initializer,
|
||||
@ -1077,17 +1079,17 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
mems=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
mems: Optional[List[tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
training: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Union[Tuple, TFTransfoXLSequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
|
||||
|
@ -19,7 +19,7 @@
|
||||
"""
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -1215,15 +1215,15 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
mems=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
mems: Optional[List[torch.FloatTensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, TransfoXLSequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
|
Loading…
Reference in New Issue
Block a user