Add type hints for several pytorch models (batch-2) (#25557)

* Add missing type hint to cpmant

* Add type hints to decision_transformer model

* Add type hints to deformable_detr models

* Add type hints to detr models

* Add type hints to deta models

* Add type hints to dpr models

* Update attention mask type hint

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update remaining attention masks type hints

* Update docstrings' type hints related to attention masks

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Author: David Reguera, 2023-08-28 14:58:23 +02:00 (committed by GitHub)
Commit: cb91ec67b5 (parent: de139702a1)
7 changed files with 103 additions and 103 deletions
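All seven files follow the same pattern: bare `forward` arguments gain `Optional[...]` hints and the signature gains an explicit return annotation. A minimal illustrative sketch of the before/after shape of the change (not a line from the diff; `ExampleModelOutput` is a placeholder name standing in for the model-specific output dataclass):

from typing import Optional, Tuple, Union

import torch

# Before: arguments and return value are untyped.
def forward(self, pixel_values, pixel_mask=None, return_dict=None):
    ...

# After: every argument carries a hint, and the return type is explicit.
def forward(
    self,
    pixel_values: torch.FloatTensor,
    pixel_mask: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], "ExampleModelOutput"]:
    ...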

src/transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -1126,7 +1126,7 @@ CONDITIONAL_DETR_INPUTS_DOCSTRING = r"""
             [What are attention masks?](../glossary#attention-mask)
-        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
         encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
@@ -1872,7 +1872,7 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
         self,
         pixel_values: torch.FloatTensor,
         pixel_mask: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
         encoder_outputs: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,

src/transformers/models/cpmant/modeling_cpmant.py

@@ -653,7 +653,7 @@ class CpmAntModel(CpmAntPreTrainedModel):
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

src/transformers/models/decision_transformer/modeling_decision_transformer.py

@@ -787,7 +787,7 @@ DECISION_TRANSFORMER_INPUTS_DOCSTRING = r"""
             The returns for each state in the trajectory
         timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
             The timestep for each step in the trajectory
-        attention_mask (`torch.LongTensor` of shape `(batch_size, episode_length)`):
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
             Masking, used to mask the actions when performing autoregressive prediction
 """
@@ -830,16 +830,16 @@ class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
     @replace_return_docstrings(output_type=DecisionTransformerOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        states=None,
-        actions=None,
-        rewards=None,
-        returns_to_go=None,
-        timesteps=None,
-        attention_mask=None,
-        output_hidden_states=None,
-        output_attentions=None,
-        return_dict=None,
-    ) -> Union[Tuple, DecisionTransformerOutput]:
+        states: Optional[torch.FloatTensor] = None,
+        actions: Optional[torch.FloatTensor] = None,
+        rewards: Optional[torch.FloatTensor] = None,
+        returns_to_go: Optional[torch.FloatTensor] = None,
+        timesteps: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DecisionTransformerOutput]:
         r"""
         Returns:

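With the new hints, the expected dtypes of `DecisionTransformerModel.forward` are readable straight off the signature. A minimal call sketch, assuming illustrative dimensions (`state_dim=11`, `act_dim=3` are arbitrary choices here, not values from the diff):

import torch
from transformers import DecisionTransformerConfig, DecisionTransformerModel

config = DecisionTransformerConfig(state_dim=11, act_dim=3)
model = DecisionTransformerModel(config)

batch, ep_len = 1, 20
outputs = model(
    states=torch.randn(batch, ep_len, config.state_dim),            # FloatTensor
    actions=torch.randn(batch, ep_len, config.act_dim),             # FloatTensor
    rewards=torch.randn(batch, ep_len, 1),                          # FloatTensor
    returns_to_go=torch.randn(batch, ep_len, 1),                    # FloatTensor
    timesteps=torch.arange(ep_len, dtype=torch.long).unsqueeze(0),  # LongTensor
    attention_mask=torch.ones(batch, ep_len, dtype=torch.float),    # FloatTensor, per the updated docstring
    return_dict=True,
)
action_preds = outputs.action_preds  # (batch, ep_len, act_dim)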
src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -1123,7 +1123,7 @@ DEFORMABLE_DETR_INPUTS_DOCSTRING = r"""
             [What are attention masks?](../glossary#attention-mask)
-        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
         encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
@@ -1625,16 +1625,16 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
     @replace_return_docstrings(output_type=DeformableDetrModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DeformableDetrModelOutput]:
         r"""
         Returns:
@@ -1885,17 +1885,17 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
     @replace_return_docstrings(output_type=DeformableDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[List[dict]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DeformableDetrObjectDetectionOutput]:
         r"""
         labels (`List[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the

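The new `labels: Optional[List[dict]]` hint matches what the docstring already describes: one dict per image, with at least the class labels and boxes. A sketch of the expected structure, following the DETR-family convention (the tensor values are made up for illustration):

import torch

# One dict per image in the batch.
labels = [
    {
        "class_labels": torch.tensor([1, 5], dtype=torch.long),  # (num_targets,)
        "boxes": torch.tensor(                                   # (num_targets, 4), normalized (cx, cy, w, h)
            [[0.5, 0.5, 0.2, 0.3],
             [0.1, 0.2, 0.1, 0.1]]
        ),
    }
]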
src/transformers/models/deta/modeling_deta.py

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -1013,7 +1013,7 @@ DETA_INPUTS_DOCSTRING = r"""
             [What are attention masks?](../glossary#attention-mask)
-        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
         encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
@@ -1533,16 +1533,16 @@ class DetaModel(DetaPreTrainedModel):
     @replace_return_docstrings(output_type=DetaModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DetaModelOutput]:
         r"""
         Returns:
@@ -1838,17 +1838,17 @@ class DetaForObjectDetection(DetaPreTrainedModel):
     @replace_return_docstrings(output_type=DetaObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[List[dict]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DetaObjectDetectionOutput]:
         r"""
         labels (`List[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the

src/transformers/models/detr/modeling_detr.py

@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torch import Tensor, nn
@@ -881,7 +881,7 @@ DETR_INPUTS_DOCSTRING = r"""
             [What are attention masks?](../glossary#attention-mask)
-        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
         encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
@@ -1245,16 +1245,16 @@ class DetrModel(DetrPreTrainedModel):
     @replace_return_docstrings(output_type=DetrModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DetrModelOutput]:
         r"""
         Returns:
@@ -1405,17 +1405,17 @@ class DetrForObjectDetection(DetrPreTrainedModel):
     @replace_return_docstrings(output_type=DetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[List[dict]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
         r"""
         labels (`List[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
@@ -1575,17 +1575,17 @@ class DetrForSegmentation(DetrPreTrainedModel):
     @replace_return_docstrings(output_type=DetrSegmentationOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[List[dict]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DetrSegmentationOutput]:
         r"""
         labels (`List[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each

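For `DetrForSegmentation` the same `List[dict]` labels additionally carry a per-target mask, matching the docstring's mention of DICE/F-1 and Focal losses. An illustrative sketch (key names follow the DETR convention; the shapes and dtype here are assumptions for illustration):

import torch

height, width = 480, 640  # illustrative image size
labels = [
    {
        "class_labels": torch.tensor([7], dtype=torch.long),
        "boxes": torch.tensor([[0.4, 0.6, 0.3, 0.2]]),
        "masks": torch.zeros(1, height, width),  # one binary mask per target
    }
]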
src/transformers/models/dpr/modeling_dpr.py

@@ -454,9 +454,9 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
         attention_mask: Optional[Tensor] = None,
         token_type_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[DPRContextEncoderOutput, Tuple[Tensor, ...]]:
         r"""
         Return:
@@ -535,9 +535,9 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
         attention_mask: Optional[Tensor] = None,
         token_type_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[DPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
         r"""
         Return:
@@ -616,9 +616,9 @@ class DPRReader(DPRPretrainedReader):
         input_ids: Optional[Tensor] = None,
         attention_mask: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
-        output_attentions: bool = None,
-        output_hidden_states: bool = None,
-        return_dict=None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
         r"""
         Return:
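The DPR encoders already had typed tensor arguments and return annotations; this change only tightens the three boolean flags to `Optional[bool]` (they default to `None` and fall back to the config). A minimal usage sketch, assuming the standard `facebook/dpr-question_encoder-single-nq-base` checkpoint:

from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

name = "facebook/dpr-question_encoder-single-nq-base"
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
model = DPRQuestionEncoder.from_pretrained(name)

inputs = tokenizer("What is deep learning?", return_tensors="pt")
# All three flags are Optional[bool]; passing None uses the config defaults.
outputs = model(**inputs, output_attentions=False, output_hidden_states=False, return_dict=True)
embedding = outputs.pooler_output  # (batch_size, hidden_size)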