Added type hints for PyTorch Marian calls (#16200)

* Added type hinting for forward functions in PyTorch Marian

* typo correction

* Removed type hints on functions from BART per Suraj Patil's request

* fix import problem

* fix typo

* corrected tuple call

* ran black

* after fix-copies
  Some Optional annotations on primitives were removed; past_key_values in MarianForCausalLM changed from Tuple of Tuple to List

* Fixing copies to RoFormer and Pegasus

Co-authored-by: Clementine Fourrier <cfourrie@inria.fr>
Co-authored-by: matt <rocketknight1@gmail.com>
Authored by Clémentine Fourrier, 2022-03-22 15:45:59 +01:00; committed by GitHub
parent a2379b9257
commit d49f8d3189
3 changed files with 82 additions and 90 deletions
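
As a quick sanity check (not part of this diff), the new annotations can be inspected at runtime with only the standard library, assuming a transformers install that includes this change:

```python
import inspect

from transformers import MarianMTModel

# Print the annotated signature of MarianMTModel.forward to see the new
# Optional[...] parameter hints and the Seq2SeqLMOutput return annotation.
sig = inspect.signature(MarianMTModel.forward)
for name, param in sig.parameters.items():
    print(f"{name}: {param.annotation}")
print(f"returns: {sig.return_annotation}")
```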

File: src/transformers/models/marian/modeling_marian.py

@@ -18,7 +18,7 @@
import copy
import math
import random
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -109,12 +109,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
-def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
super().__init__(num_positions, embedding_dim)
self.weight = self._init_weight(self.weight)
@staticmethod
-def _init_weight(out: nn.Parameter):
+def _init_weight(out: nn.Parameter) -> nn.Parameter:
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
@@ -131,7 +131,7 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding):
return out
@torch.no_grad()
-def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
positions = torch.arange(
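
For reference, the docstring above describes a non-interleaved layout: sin features fill the first half of each embedding vector and cos features the second half (`[dim // 2:]`). Below is a minimal standalone sketch of such a table (an illustration only, not the library's exact `_init_weight` implementation, and assuming an even `dim`):

```python
import torch


def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # Sin features occupy the first half of each vector, cos features the
    # second half ([dim // 2:]); the two halves are not interleaved.
    position = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)  # (n_pos, 1)
    exponent = torch.arange(0, dim, 2, dtype=torch.float32) / dim     # (dim // 2,)
    angles = position / (10000.0 ** exponent)                         # (n_pos, dim // 2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (n_pos, dim)
```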
@@ -477,7 +477,7 @@ class MarianPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
-def _init_weights(self, module):
+def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
@@ -665,9 +665,7 @@ class MarianEncoder(MarianPreTrainedModel):
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_positions = MarianSinusoidalPositionalEmbedding(
-config.max_position_embeddings,
-embed_dim,
-self.padding_idx,
+config.max_position_embeddings, embed_dim, self.padding_idx
)
self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -683,14 +681,14 @@ class MarianEncoder(MarianPreTrainedModel):
def forward(
self,
-input_ids=None,
-attention_mask=None,
-head_mask=None,
-inputs_embeds=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+input_ids: torch.LongTensor = None,
+attention_mask: Optional[torch.LongTensor] = None,
+head_mask: Optional[torch.Tensor] = None,
+inputs_embeds: Optional[torch.FloatTensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -833,9 +831,7 @@ class MarianDecoder(MarianPreTrainedModel):
self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx)
self.embed_positions = MarianSinusoidalPositionalEmbedding(
-config.max_position_embeddings,
-config.d_model,
-self.padding_idx,
+config.max_position_embeddings, config.d_model, self.padding_idx
)
self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
@@ -870,19 +866,19 @@ class MarianDecoder(MarianPreTrainedModel):
def forward(
self,
-input_ids=None,
-attention_mask=None,
-encoder_hidden_states=None,
-encoder_attention_mask=None,
-head_mask=None,
-cross_attn_head_mask=None,
-past_key_values=None,
-inputs_embeds=None,
-use_cache=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+input_ids: torch.LongTensor = None,
+attention_mask: Optional[torch.Tensor] = None,
+encoder_hidden_states: Optional[torch.FloatTensor] = None,
+encoder_attention_mask: Optional[torch.LongTensor] = None,
+head_mask: Optional[torch.Tensor] = None,
+cross_attn_head_mask: Optional[torch.Tensor] = None,
+past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+inputs_embeds: Optional[torch.FloatTensor] = None,
+use_cache: Optional[bool] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1082,8 +1078,7 @@ class MarianDecoder(MarianPreTrainedModel):
@add_start_docstrings(
-"The bare Marian Model outputting raw hidden-states without any specific head on top.",
-MARIAN_START_DOCSTRING,
+"The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
)
class MarianModel(MarianPreTrainedModel):
def __init__(self, config: MarianConfig):
@@ -1143,7 +1138,7 @@ class MarianModel(MarianPreTrainedModel):
def get_decoder(self):
return self.decoder
-def resize_decoder_token_embeddings(self, new_num_tokens):
+def resize_decoder_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
if self.config.share_encoder_decoder_embeddings:
raise ValueError(
"`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
@@ -1171,22 +1166,22 @@ class MarianModel(MarianPreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-input_ids=None,
-attention_mask=None,
-decoder_input_ids=None,
-decoder_attention_mask=None,
-head_mask=None,
-decoder_head_mask=None,
-cross_attn_head_mask=None,
-encoder_outputs=None,
-past_key_values=None,
-inputs_embeds=None,
-decoder_inputs_embeds=None,
-use_cache=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+input_ids: torch.LongTensor = None,
+attention_mask: Optional[torch.Tensor] = None,
+decoder_input_ids: Optional[torch.LongTensor] = None,
+decoder_attention_mask: Optional[torch.Tensor] = None,
+head_mask: Optional[torch.Tensor] = None,
+decoder_head_mask: Optional[torch.Tensor] = None,
+cross_attn_head_mask: Optional[torch.Tensor] = None,
+encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+inputs_embeds: Optional[torch.FloatTensor] = None,
+decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+use_cache: Optional[bool] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Seq2SeqModelOutput:
r"""
Returns:
@@ -1279,10 +1274,7 @@ class MarianMTModel(MarianPreTrainedModel):
r"embed_positions",
]
-_keys_to_ignore_on_save = [
-"model.encoder.embed_positions.weight",
-"model.decoder.embed_positions.weight",
-]
+_keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
def __init__(self, config: MarianConfig):
super().__init__(config)
@@ -1309,7 +1301,7 @@ class MarianMTModel(MarianPreTrainedModel):
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings
-def _resize_token_embeddings(self, new_num_tokens):
+def _resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.set_input_embeddings(new_embeddings)
@@ -1370,7 +1362,7 @@ class MarianMTModel(MarianPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
-def set_output_embeddings(self, new_embeddings):
+def set_output_embeddings(self, new_embeddings: nn.Embedding):
self.lm_head = new_embeddings
def tie_weights(self):
@@ -1400,23 +1392,23 @@ class MarianMTModel(MarianPreTrainedModel):
@add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
def forward(
self,
-input_ids=None,
-attention_mask=None,
-decoder_input_ids=None,
-decoder_attention_mask=None,
-head_mask=None,
-decoder_head_mask=None,
-cross_attn_head_mask=None,
-encoder_outputs=None,
-past_key_values=None,
-inputs_embeds=None,
-decoder_inputs_embeds=None,
-labels=None,
-use_cache=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+input_ids: torch.LongTensor = None,
+attention_mask: Optional[torch.Tensor] = None,
+decoder_input_ids: Optional[torch.LongTensor] = None,
+decoder_attention_mask: Optional[torch.Tensor] = None,
+head_mask: Optional[torch.Tensor] = None,
+decoder_head_mask: Optional[torch.Tensor] = None,
+cross_attn_head_mask: Optional[torch.Tensor] = None,
+encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+inputs_embeds: Optional[torch.FloatTensor] = None,
+decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+labels: Optional[torch.LongTensor] = None,
+use_cache: Optional[bool] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Seq2SeqLMOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1479,16 +1471,16 @@ class MarianMTModel(MarianPreTrainedModel):
def prepare_inputs_for_generation(
self,
-decoder_input_ids,
-past=None,
-attention_mask=None,
-head_mask=None,
-decoder_head_mask=None,
-cross_attn_head_mask=None,
-use_cache=None,
-encoder_outputs=None,
-**kwargs
-):
+decoder_input_ids: torch.LongTensor,
+past: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+attention_mask: Optional[torch.Tensor] = None,
+head_mask: Optional[torch.Tensor] = None,
+decoder_head_mask: Optional[torch.Tensor] = None,
+cross_attn_head_mask: Optional[torch.Tensor] = None,
+use_cache: Optional[bool] = None,
+encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+**kwargs,
+) -> Dict:
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
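
The following usage sketch (an illustration, not part of this diff; `Helsinki-NLP/opus-mt-en-de` is just one example checkpoint, and network access to the Hugging Face Hub is assumed) shows the annotated `forward` in action: `generate` drives it internally, and `return_dict` selects between the tuple and dataclass branches of the `Union` return annotations.

```python
import torch
from transformers import MarianMTModel, MarianTokenizer

model_name = "Helsinki-NLP/opus-mt-en-de"  # example checkpoint; any Marian model works
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

inputs = tokenizer("Type hints make signatures easier to read.", return_tensors="pt")

# Translation via generate(), which calls the annotated forward() under the hood.
translated = model.generate(**inputs)
print(tokenizer.batch_decode(translated, skip_special_tokens=True))

# return_dict picks the branch of the Union[Tuple[...], BaseModelOutput] annotation.
with torch.no_grad():
    as_dataclass = model.get_encoder()(**inputs, return_dict=True)   # BaseModelOutput
    as_tuple = model.get_encoder()(**inputs, return_dict=False)      # plain tuple
print(type(as_dataclass).__name__, type(as_tuple).__name__)
```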

File: src/transformers/models/pegasus/modeling_pegasus.py

@@ -109,12 +109,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
-def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
super().__init__(num_positions, embedding_dim)
self.weight = self._init_weight(self.weight)
@staticmethod
-def _init_weight(out: nn.Parameter):
+def _init_weight(out: nn.Parameter) -> nn.Parameter:
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
@@ -131,7 +131,7 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
return out
@torch.no_grad()
-def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
positions = torch.arange(

File: src/transformers/models/roformer/modeling_roformer.py

@@ -73,12 +73,12 @@ ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
-def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
super().__init__(num_positions, embedding_dim)
self.weight = self._init_weight(self.weight)
@staticmethod
-def _init_weight(out: nn.Parameter):
+def _init_weight(out: nn.Parameter) -> nn.Parameter:
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
@@ -95,7 +95,7 @@ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
return out
@torch.no_grad()
-def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
positions = torch.arange(