Type hints MCTCT (#19618)

* add type hints to mctct

* run auto style corrections

* change torch.bool to bool

* Update src/transformers/models/mctct/modeling_mctct.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove optional tags for attention_mask and head_mask

* fix optional tags

* Update src/transformers/models/mctct/modeling_mctct.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Authored by Ryan Chan on 2022-10-17 14:15:21 +01:00, committed by GitHub
commit d7754c43d0 (parent 8aad4363d8)

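The pattern these commits describe (required tensor arguments on the encoder, `Optional[...] = None` defaults on the user-facing models, plain `bool` rather than `torch.bool` for flags, and a `Union[Tuple, ...]` return annotation) looks roughly like the sketch below. The class and argument names here are hypothetical stand-ins; the real signatures are in the diff that follows.

from typing import Optional, Tuple, Union

import torch

from transformers.modeling_outputs import BaseModelOutput


class MySpeechModel:  # hypothetical stand-in, not part of this PR
    def forward(
        self,
        input_features: torch.Tensor,                   # always required
        attention_mask: Optional[torch.Tensor] = None,  # optional at the model level
        output_attentions: Optional[bool] = None,       # plain bool, not torch.bool
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:                 # tuple when return_dict=False
        ...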

@@ -17,7 +17,7 @@
 import math
 import random
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -566,13 +566,13 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
     def forward(
         self,
-        input_features,
-        attention_mask,
-        head_mask,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        input_features: torch.Tensor,
+        attention_mask: torch.Tensor,
+        head_mask: torch.Tensor,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[Tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
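The `Union[Tuple, BaseModelOutput]` annotation mirrors the usual `return_dict` convention: a plain tuple comes back when `return_dict=False`, a `BaseModelOutput` otherwise. A minimal sketch of how a caller handles both branches, assuming `encoder`, `features`, `mask`, and `head_mask` already exist:

# BaseModelOutput branch of the Union
outputs = encoder(features, mask, head_mask, return_dict=True)
hidden_states = outputs.last_hidden_state

# plain tuple branch of the Union
outputs = encoder(features, mask, head_mask, return_dict=False)
hidden_states = outputs[0]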
@@ -680,13 +680,13 @@ class MCTCTModel(MCTCTPreTrainedModel):
     )
     def forward(
         self,
-        input_features,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_features: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
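Because everything after `input_features` in `MCTCTModel.forward` now defaults to `None`, the model can be called with the features alone. A rough usage sketch; the checkpoint name and the 80-dimensional dummy features are assumptions for illustration (in practice the features come from the model's feature extractor):

import torch
from transformers import MCTCTModel

model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")

# Dummy log-mel style features: (batch, time, feature_dim)
features = torch.randn(1, 100, 80)

with torch.no_grad():
    outputs = model(features)  # attention_mask, head_mask, etc. all default to None

print(outputs.last_hidden_state.shape)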
@@ -751,14 +751,14 @@ class MCTCTForCTC(MCTCTPreTrainedModel):
     )
     def forward(
         self,
-        input_features,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-    ):
+        input_features: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
             Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
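A rough illustration of the `labels` contract documented here; the values are made up, and padding with -100 is the usual convention for Hugging Face CTC heads:

import torch

# (batch_size, target_length) of token ids in [0, vocab_size);
# -100 marks padded positions that the loss should ignore.
labels = torch.tensor([
    [12,  5,  5,  9, -100, -100],  # shorter transcript, padded
    [ 3,  7,  1,  1,    8,   22],
])

# target_length (6 here) must not exceed the time dimension of the
# logits produced for the corresponding audio input.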
@@ -783,7 +783,6 @@ class MCTCTForCTC(MCTCTPreTrainedModel):
         loss = None
         if labels is not None:
             if labels.max() >= self.config.vocab_size:
                 raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
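For context, the check above guards the CTC loss computation that follows it in the model. A minimal sketch of that computation under the usual Hugging Face CTC conventions; this is illustrative, not the verbatim MCTCT code:

import torch
import torch.nn.functional as F

def ctc_loss_sketch(logits, labels, input_lengths, blank_id=0):
    # logits: (batch, time, vocab_size); labels: (batch, target_length) padded with -100
    log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)  # (time, batch, vocab)

    labels_mask = labels >= 0                  # -100 marks padding
    target_lengths = labels_mask.sum(-1)
    flattened_targets = labels.masked_select(labels_mask)

    return F.ctc_loss(
        log_probs,
        flattened_targets,
        input_lengths,
        target_lengths,
        blank=blank_id,
        zero_infinity=True,
    )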