Add type hints for PyTorch UniSpeech, MPNet and Nystromformer (#19039)

* added type hints for PyTorch UniSpeech

* added type hints for PyTorch MPNet

* added type hints for Nystromformer

* resolved copy inconsistencies

* make fix-copies

Co-authored-by: matt <rocketknight1@gmail.com>
Author: Partho
Date: 2022-09-16 22:29:40 +05:30 (committed by GitHub)
Commit: 5e636eee4a (parent 658010c739)
7 changed files with 61 additions and 41 deletions
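
The commit message targets three models, yet seven files change. That is expected: Transformers keeps modules like the Data2VecAudio, Hubert, UniSpeechSat and Wav2Vec2 encoders in sync via `# Copied from` comments, and `make fix-copies` regenerates every declared copy from its source, so a signature change in one model propagates to the others. A minimal sketch of the convention (an illustrative excerpt, not the verbatim library code):

import torch.nn as nn

# Illustrative sketch of the "# Copied from" convention used across
# transformers model files. `make fix-copies` finds these comments and
# rewrites the annotated block from its source, applying the name
# substitution, so an edit to Wav2Vec2Encoder lands here automatically.

# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert
class HubertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()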

src/transformers/models/data2vec/modeling_data2vec_audio.py

@@ -563,11 +563,11 @@ class Data2VecAudioEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
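
One detail worth noting in the hunk above: `hidden_states` is annotated as `torch.tensor` (the factory function) rather than `torch.Tensor` (the tensor class), which is the form the other parameters in the same hunk use and the one static type checkers expect. A minimal, self-contained sketch of the intended style (`encoder_forward` is an illustrative name, not part of the commit):

from typing import Optional

import torch

def encoder_forward(
    hidden_states: torch.Tensor,  # torch.Tensor is the type; torch.tensor is a constructor
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> torch.Tensor:
    # Placeholder body: the real encoders run attention and feed-forward layers here.
    return hidden_states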

src/transformers/models/hubert/modeling_hubert.py

@@ -618,7 +618,12 @@ class HubertEncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = HubertFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -649,11 +654,11 @@ class HubertEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

src/transformers/models/mpnet/modeling_mpnet.py

@@ -323,12 +323,12 @@ class MPNetEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=False,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = False,
         **kwargs,
     ):
         position_bias = self.compute_position_bias(hidden_states)

src/transformers/models/nystromformer/modeling_nystromformer.py

@@ -354,12 +354,12 @@ class NystromformerEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

src/transformers/models/unispeech/modeling_unispeech.py

@@ -655,7 +655,12 @@ class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = UniSpeechFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -686,11 +691,11 @@ class UniSpeechEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

src/transformers/models/unispeech_sat/modeling_unispeech_sat.py

@@ -669,7 +669,12 @@ class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = UniSpeechSatFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -700,11 +705,11 @@ class UniSpeechSatEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

src/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -704,7 +704,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = Wav2Vec2FeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -734,11 +739,11 @@ class Wav2Vec2Encoder(nn.Module):
 
    def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
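
With the signatures annotated, call sites can be validated by a static type checker, and optional arguments are explicit in the API. A small runnable sketch against a stand-in that mirrors the annotated encoder signature (`TinyEncoder` is illustrative, not a transformers class):

from typing import Optional

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # Mirrors the annotated forward signature from the hunks above.
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return hidden_states

encoder = TinyEncoder()
out = encoder(torch.zeros(1, 10, 32))  # attention_mask may be omitted because it is Optional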