mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add type hints for PyTorch UniSpeech, MPNet and Nystromformer (#19039)
* added type hints pytorch unispeech * added type hints pytorch MPNet * added type hints nystromformer * resolved copy inconsistencies * make fix-copies Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
parent
658010c739
commit
5e636eee4a
@ -563,11 +563,11 @@ class Data2VecAudioEncoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
hidden_states: torch.tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
@ -618,7 +618,12 @@ class HubertEncoderLayerStableLayerNorm(nn.Module):
|
||||
self.feed_forward = HubertFeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
attn_residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
@ -649,11 +654,11 @@ class HubertEncoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
hidden_states: torch.tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
@ -323,12 +323,12 @@ class MPNetEncoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=False,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
position_bias = self.compute_position_bias(hidden_states)
|
||||
|
@ -354,12 +354,12 @@ class NystromformerEncoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
@ -655,7 +655,12 @@ class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
|
||||
self.feed_forward = UniSpeechFeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
attn_residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
@ -686,11 +691,11 @@ class UniSpeechEncoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
hidden_states: torch.tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
@ -669,7 +669,12 @@ class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
|
||||
self.feed_forward = UniSpeechSatFeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
attn_residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
@ -700,11 +705,11 @@ class UniSpeechSatEncoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
hidden_states: torch.tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
@ -704,7 +704,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
attn_residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
@ -734,11 +739,11 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
hidden_states: torch.tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
Loading…
Reference in New Issue
Block a user