Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Pytorch type hints (#20112)
* initial commit
* Update modeling_whisper.py
* Fixing Tests
* modeling_vision_text_dual_encoder
* modeling_vision_encoder_decoder
* Update modeling_vit.py
* Update modeling_vit_msn.py
* Update modeling_trajectory_transformer.py
* style
* Update modeling_time_series_transformer.py
* Update modeling_time_series_transformer.py
* Update modeling_segformer.py
* Update modeling_plbart.py
* Update modeling_dpt.py
* Update modeling_deit.py
* Update modeling_dpt.py
* Update modeling_esm.py
* Update modeling_fnet.py
* Update modeling_fnet.py
* Update modeling_fnet.py
* Update modeling_flava.py
* Update modeling_flava.py
* Update modeling_layoutlmv3.py
* Update modeling_levit.py
This commit is contained in:
parent 03bc6ece1b
commit d24e84d9ed
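Every hunk in this diff applies the same pattern: previously untyped `forward` arguments gain `torch` tensor annotations wrapped in `typing.Optional`, and the bare `):` closing each signature gains a `Union[...]` return annotation pointing at the model's output class. A minimal sketch of that pattern is shown below; the `ToyModel` class is illustrative only and is not part of this commit.

from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_outputs import BaseModelOutputWithPooling


class ToyModel(nn.Module):
    # Before: def forward(self, pixel_values, head_mask=None, return_dict=None):
    # After: every argument and the return value carry a type hint.
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        raise NotImplementedError  # signature illustration only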
@@ -494,7 +494,7 @@ class DeiTModel(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -712,12 +712,12 @@ class DPTModel(DPTPreTrainedModel):
     )
     def forward(
         self,
-        pixel_values,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -875,13 +875,13 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
     @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        head_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth depth estimation maps for computing the loss.
@@ -1036,13 +1036,13 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
     @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
@@ -940,7 +940,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1042,7 +1042,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1138,7 +1138,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -943,7 +943,7 @@ class FlavaImageModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1039,7 +1039,7 @@ class FlavaTextModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1142,7 +1142,7 @@ class FlavaMultimodalModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -548,13 +548,13 @@ class FNetModel(FNetPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        token_type_ids=None,
-        position_ids=None,
-        inputs_embeds=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -848,18 +848,18 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        bbox=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        pixel_values=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        bbox: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Returns:

@@ -16,7 +16,7 @@

 import itertools
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -561,7 +561,7 @@ class LevitModel(LevitPreTrainedModel):
         pixel_values: torch.FloatTensor = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -630,7 +630,7 @@ class LevitForImageClassification(LevitPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
@@ -722,7 +722,7 @@ class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
         pixel_values: torch.FloatTensor = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, LevitForImageClassificationWithTeacherOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
@@ -1176,7 +1176,7 @@ class PLBartModel(PLBartPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1159,16 +1159,16 @@ class RealmEmbedder(RealmPreTrainedModel):
     @replace_return_docstrings(output_type=RealmEmbedderOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, RealmEmbedderOutput]:
         r"""
         Returns:

@@ -1241,20 +1241,20 @@ class RealmScorer(RealmPreTrainedModel):
     @replace_return_docstrings(output_type=RealmScorerOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        candidate_input_ids=None,
-        candidate_attention_mask=None,
-        candidate_token_type_ids=None,
-        candidate_inputs_embeds=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        candidate_input_ids: Optional[torch.LongTensor] = None,
+        candidate_attention_mask: Optional[torch.FloatTensor] = None,
+        candidate_token_type_ids: Optional[torch.LongTensor] = None,
+        candidate_inputs_embeds: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, RealmScorerOutput]:
         r"""
         candidate_input_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`):
             Indices of candidate input sequence tokens in the vocabulary.
@@ -1396,19 +1396,19 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        relevance_score=None,
-        labels=None,
-        mlm_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        relevance_score: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        mlm_mask: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MaskedLMOutput]:
         r"""
         relevance_score (`torch.FloatTensor` of shape `(batch_size, num_candidates)`, *optional*):
             Relevance score derived from RealmScorer, must be specified if you want to compute the masked language
@@ -1537,21 +1537,21 @@ class RealmReader(RealmPreTrainedModel):
     @replace_return_docstrings(output_type=RealmReaderOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        relevance_score=None,
-        block_mask=None,
-        start_positions=None,
-        end_positions=None,
-        has_answers=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        relevance_score: Optional[torch.FloatTensor] = None,
+        block_mask: Optional[torch.BoolTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        has_answers: Optional[torch.BoolTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, RealmReaderOutput]:
         r"""
         relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
             Relevance score, which must be specified if you want to compute the logits and marginal log loss.
@@ -1763,12 +1763,12 @@ class RealmForOpenQA(RealmPreTrainedModel):
     @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        answer_ids=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        answer_ids: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, RealmForOpenQAOutput]:
         r"""
         Returns:

@@ -706,7 +706,7 @@ class SegformerDecodeHead(SegformerPreTrainedModel):

         self.config = config

-    def forward(self, encoder_hidden_states):
+    def forward(self, encoder_hidden_states: torch.FloatTensor):
         batch_size = encoder_hidden_states[-1].shape[0]

         all_hidden_states = ()
@@ -15,7 +15,7 @@
 """ Classes to support Speech-Encoder-Text-Decoder architectures"""


-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn
@@ -443,22 +443,22 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        inputs=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        input_values=None,
-        input_features=None,
-        return_dict=None,
+        inputs: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        input_values: Optional[torch.FloatTensor] = None,
+        input_features: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         Returns:

@@ -17,7 +17,7 @@

 import math
 import random
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn
@@ -1144,21 +1144,21 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_features=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_features: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         Returns:

@@ -1291,22 +1291,22 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_features=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_features: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn
@@ -780,20 +780,20 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1584,7 +1584,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Seq2SeqTimeSeriesModelOutput, Tuple]:
         r"""
         Returns:

@@ -1747,7 +1747,7 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Seq2SeqTimeSeriesModelOutput, Tuple]:
         r"""
         Returns:

@@ -17,7 +17,7 @@
 import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import numpy as np
 import torch
@@ -478,7 +478,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple[torch.Tensor], TrajectoryTransformerOutput]:
         r"""
         Returns:

@@ -18,7 +18,7 @@
 import gc
 import os
 import tempfile
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn
@@ -520,19 +520,19 @@ class VisionEncoderDecoderModel(PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         Returns:

@@ -15,7 +15,7 @@
 """ PyTorch VisionTextDualEncoder model."""


-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn
@@ -295,16 +295,16 @@ class VisionTextDualEncoderModel(PreTrainedModel):
     @replace_return_docstrings(output_type=CLIPOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        pixel_values=None,
-        attention_mask=None,
-        position_ids=None,
-        return_loss=None,
-        token_type_ids=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
         r"""
         Returns:

@@ -541,7 +541,7 @@ class ViTModel(ViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -525,7 +525,7 @@ class ViTMSNModel(ViTMSNPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[tuple, BaseModelOutput]:
         r"""
         Returns:

@@ -17,7 +17,7 @@

 import math
 import random
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -1004,20 +1004,20 @@ class WhisperModel(WhisperPreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_features=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_features: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
         r"""
         Returns:

@@ -1140,21 +1140,21 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_features=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_features: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
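To check that the new annotations resolve at runtime, a quick sanity check like the one below can be run against an installation that includes this commit; the snippet is not part of the diff, it only inspects the now-annotated signature of WhisperModel.forward.

import inspect
from typing import get_type_hints

from transformers import WhisperModel

# Resolved annotations for the type-hinted forward method.
for name, annotation in get_type_hints(WhisperModel.forward).items():
    print(name, annotation)

# The full signature, including defaults and the return annotation.
print(inspect.signature(WhisperModel.forward))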