mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Type hints MCTCT (#19618)
* add type hints to mctct * run auto style corrections * change torch.bool to bool# * Update src/transformers/models/mctct/modeling_mctct.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Remove optional tags for attention_mask and head_mask' * fix optional tags' * Update src/transformers/models/mctct/modeling_mctct.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent
8aad4363d8
commit
d7754c43d0
@ -17,7 +17,7 @@
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -566,13 +566,13 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_features,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
input_features: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
head_mask: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -680,13 +680,13 @@ class MCTCTModel(MCTCTPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_features,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
input_features: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -751,14 +751,14 @@ class MCTCTForCTC(MCTCTPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_features,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
):
|
||||
input_features: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
||||
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
||||
@ -783,7 +783,6 @@ class MCTCTForCTC(MCTCTPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
if labels.max() >= self.config.vocab_size:
|
||||
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user