Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
added type hints for BART model (#16270)
* added type hints for BART model
* make fixup, adding imports to copied files
* Adding some missing types to cookiecutter
* Adding some missing types to cookiecutter
* Adding some missing types to cookiecutter

Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
parent
460f36d352
commit
d50f62f2de
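The change is purely additive annotation work. As a rough illustration of what the new signatures buy downstream users (a minimal sketch, assuming the public facebook/bart-base checkpoint and the standard high-level transformers API; this snippet is not part of the diff), a type checker or IDE can now resolve the output type of an annotated forward call:

# Illustrative usage only (not part of this commit): exercises the newly
# annotated BART signatures. Assumes the public facebook/bart-base checkpoint.
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello world", return_tensors="pt")
# With return_dict=True the annotated Union[Tuple, Seq2SeqLMOutput] return
# narrows to Seq2SeqLMOutput, so static checkers can resolve .logits.
outputs: Seq2SeqLMOutput = model(**inputs, return_dict=True)
print(outputs.logits.shape)  # torch.Size([1, sequence_length, vocab_size])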
@@ -17,7 +17,7 @@ import copy
 import math
 import random
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -297,11 +297,11 @@ class BartEncoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -384,7 +384,7 @@ class BartDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -478,7 +478,7 @@ class BartClassificationHead(nn.Module):
         self.dropout = nn.Dropout(p=pooler_dropout)
         self.out_proj = nn.Linear(inner_dim, num_classes)
 
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.dense(hidden_states)
         hidden_states = torch.tanh(hidden_states)
@@ -728,14 +728,14 @@ class BartEncoder(BartPretrainedModel):
 
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -917,19 +917,19 @@ class BartDecoder(BartPretrainedModel):
 
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1172,22 +1172,22 @@ class BartModel(BartPretrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
 
         # different to other models, Bart automatically creates decoder_input_ids from
         # input_ids if no decoder_input_ids are provided
@@ -1306,23 +1306,23 @@ class BartForConditionalGeneration(BartPretrainedModel):
     @add_end_docstrings(BART_GENERATION_EXAMPLE)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1454,22 +1454,22 @@ class BartForSequenceClassification(BartPretrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1580,23 +1580,23 @@ class BartForQuestionAnswering(BartPretrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        start_positions=None,
-        end_positions=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1721,20 +1721,20 @@ class BartForCausalLM(BartPretrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -18,6 +18,7 @@
 import random
 from typing import Optional, Tuple, Union
 
+import numpy as np
 import tensorflow as tf
 
 from ...activations_tf import get_tf_activation
@@ -39,6 +40,7 @@ from ...modeling_tf_outputs import (
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
     TFCausalLanguageModelingLoss,
+    TFModelInputType,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -170,7 +172,7 @@ class TFBartAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -297,7 +299,13 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
 
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
+        layer_head_mask: Optional[tf.Tensor],
+        training: Optional[bool] = False,
+    ) -> tf.Tensor:
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -365,14 +373,14 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
 
     def call(
         self,
-        hidden_states,
-        attention_mask: Optional[tf.Tensor] = None,
-        encoder_hidden_states: Optional[tf.Tensor] = None,
-        encoder_attention_mask: Optional[tf.Tensor] = None,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
-        past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
@@ -663,16 +671,16 @@ class TFBartEncoder(tf.keras.layers.Layer):
    @unpack_inputs
    def call(
        self,
-        input_ids=None,
-        inputs_embeds=None,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
         """
         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -813,21 +821,21 @@ class TFBartDecoder(tf.keras.layers.Layer):
    @unpack_inputs
    def call(
        self,
-        input_ids=None,
-        inputs_embeds=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
         r"""
         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -1030,24 +1038,24 @@ class TFBartMainLayer(tf.keras.layers.Layer):
    @unpack_inputs
    def call(
        self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs
-    ):
+    ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
 
         if decoder_input_ids is None and decoder_inputs_embeds is None:
             use_cache = False
@@ -1143,24 +1151,24 @@ class TFBartModel(TFBartPretrainedModel):
    @unpack_inputs
    def call(
        self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs
-    ):
+    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
 
         outputs = self.model(
             input_ids=input_ids,
@@ -1248,25 +1256,25 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
    @unpack_inputs
    def call(
        self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         encoder_outputs: Optional[TFBaseModelOutput] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -1573,7 +1573,7 @@ class BigBirdPegasusClassificationHead(nn.Module):
         self.dropout = nn.Dropout(p=pooler_dropout)
         self.out_proj = nn.Linear(inner_dim, num_classes)
 
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.dense(hidden_states)
         hidden_states = torch.tanh(hidden_states)
@@ -2367,22 +2367,22 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
 
         # different to other models, BigBirdPegasus automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
@@ -2503,23 +2503,23 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
     @add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -2652,22 +2652,22 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -2779,23 +2779,23 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        start_positions=None,
-        end_positions=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -20,7 +20,7 @@ import math
 import os
 import random
 import warnings
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -1440,20 +1440,20 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -173,7 +173,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -288,11 +288,11 @@ class BlenderbotSmallEncoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -376,7 +376,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1411,20 +1411,20 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -18,6 +18,7 @@
 import random
 from typing import Optional, Tuple, Union
 
+import numpy as np
 import tensorflow as tf
 
 from ...activations_tf import get_tf_activation
@@ -172,7 +173,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -300,7 +301,13 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
 
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
+        layer_head_mask: Optional[tf.Tensor],
+        training: Optional[bool] = False,
+    ) -> tf.Tensor:
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -369,14 +376,14 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
 
     def call(
         self,
-        hidden_states,
-        attention_mask: Optional[tf.Tensor] = None,
-        encoder_hidden_states: Optional[tf.Tensor] = None,
-        encoder_attention_mask: Optional[tf.Tensor] = None,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
-        past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
@@ -754,7 +754,7 @@ class TFHubertAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -305,11 +305,11 @@ class MarianEncoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -393,7 +393,7 @@ class MarianDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1573,20 +1573,20 @@ class MarianForCausalLM(MarianPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -212,7 +212,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -340,7 +340,13 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
 
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
+        layer_head_mask: Optional[tf.Tensor],
+        training: Optional[bool] = False,
+    ) -> tf.Tensor:
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -409,14 +415,14 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
 
     def call(
         self,
-        hidden_states,
-        attention_mask: Optional[tf.Tensor] = None,
-        encoder_hidden_states: Optional[tf.Tensor] = None,
-        encoder_attention_mask: Optional[tf.Tensor] = None,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
-        past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
@@ -16,7 +16,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -485,7 +485,7 @@ class MBartClassificationHead(nn.Module):
         self.dropout = nn.Dropout(p=pooler_dropout)
         self.out_proj = nn.Linear(inner_dim, num_classes)
 
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.dense(hidden_states)
         hidden_states = torch.tanh(hidden_states)
@@ -1445,22 +1445,22 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1572,23 +1572,23 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
     # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        start_positions=None,
-        end_positions=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1715,20 +1715,20 @@ class MBartForCausalLM(MBartPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -172,7 +172,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -17,7 +17,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -1543,20 +1543,20 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
     # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -213,7 +213,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -16,7 +16,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -302,11 +302,11 @@ class PLBartEncoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -390,7 +390,7 @@ class PLBartDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -485,7 +485,7 @@ class PLBartClassificationHead(nn.Module):
         self.dropout = nn.Dropout(p=pooler_dropout)
         self.out_proj = nn.Linear(inner_dim, num_classes)
 
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.dense(hidden_states)
         hidden_states = torch.tanh(hidden_states)
@@ -699,14 +699,14 @@ class PLBartEncoder(PLBartPreTrainedModel):
 
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -889,19 +889,19 @@ class PLBartDecoder(PLBartPreTrainedModel):
 
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1416,22 +1416,22 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1562,20 +1562,20 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -274,7 +274,7 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -783,7 +783,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -25,7 +25,7 @@ import math
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, List
 
 import datasets
 from datasets import load_dataset
@@ -1571,7 +1571,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
 import math
 import copy
 import random
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Union
 
 import torch
 from torch import nn