Mirror of https://github.com/huggingface/transformers.git
Added type hints for Pytorch Marian calls (#16200)
* Added type hinting for forward functions in pytorch marian
* typo correction
* Removed type hints on functions from BART per Suraj Patil request
* fix import pb
* fix typo
* corrected tuple call
* ran black
* after fix-copies: some Optional tags on primitives were removed, past_key_values in MarianForCausalLM changed from Tuple of Tuple to List
* Fixing copies to roformer and pegasus

Co-authored-by: Clementine Fourrier <cfourrie@inria.fr>
Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
parent a2379b9257
commit d49f8d3189

The diff below covers the Marian modeling code; the matching fix-copies updates to the Pegasus and RoFormer positional embeddings follow at the end.
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -109,12 +109,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 class MarianSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""

-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
         self.weight = self._init_weight(self.weight)

     @staticmethod
-    def _init_weight(out: nn.Parameter):
+    def _init_weight(out: nn.Parameter) -> nn.Parameter:
         """
         Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
         the 2nd half of the vector. [dim // 2:]
@@ -131,7 +131,7 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding):
         return out

     @torch.no_grad()
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids_shape[:2]
         positions = torch.arange(
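For reference, a minimal standalone sketch of the non-interleaved layout that `_init_weight` produces, as described in the docstring above (sin features in the first half of each vector, cos features in the second half). This is an illustrative function, not the exact Marian implementation:

import numpy as np
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # angle rates follow the usual 1 / 10000^(2i/dim) schedule
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.zeros(n_pos, dim)
    sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
    out[:, :sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # first half: sin
    out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))  # second half: cos
    return out

print(sinusoidal_table(4, 6).shape)  # torch.Size([4, 6])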
@@ -477,7 +477,7 @@ class MarianPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True

-    def _init_weights(self, module):
+    def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
         std = self.config.init_std
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=std)
@@ -665,9 +665,7 @@ class MarianEncoder(MarianPreTrainedModel):
             self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

         self.embed_positions = MarianSinusoidalPositionalEmbedding(
-            config.max_position_embeddings,
-            embed_dim,
-            self.padding_idx,
+            config.max_position_embeddings, embed_dim, self.padding_idx
         )
         self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])

@@ -683,14 +681,14 @@ class MarianEncoder(MarianPreTrainedModel):

     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
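The `Union[Tuple[torch.Tensor], BaseModelOutput]` annotation reflects the usual `return_dict` convention: a plain tuple when `return_dict=False`, a model-output dataclass otherwise. A toy illustration of that contract (not the Marian encoder itself; assumes transformers is installed):

from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutput

def toy_encode(hidden: torch.Tensor, return_dict: Optional[bool] = None) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
    if return_dict is False:
        return (hidden,)                              # legacy tuple output
    return BaseModelOutput(last_hidden_state=hidden)  # dataclass output

out = toy_encode(torch.zeros(1, 4, 8))
assert out.last_hidden_state.shape == (1, 4, 8)
assert toy_encode(torch.zeros(1, 4, 8), return_dict=False)[0].shape == (1, 4, 8)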
@@ -833,9 +831,7 @@ class MarianDecoder(MarianPreTrainedModel):
             self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx)

         self.embed_positions = MarianSinusoidalPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-            self.padding_idx,
+            config.max_position_embeddings, config.d_model, self.padding_idx
         )
         self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])

@@ -870,19 +866,19 @@ class MarianDecoder(MarianPreTrainedModel):

     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
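The `past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]]` annotation describes the decoding cache: one inner tuple per decoder layer. In the transformers convention of this era, each inner tuple holds the self-attention key/value states plus the cross-attention key/value states, each of shape (batch_size, num_heads, sequence_length, head_dim). A small sketch of that nesting, with illustrative sizes:

import torch

batch, heads, past_len, head_dim, num_layers = 1, 8, 5, 64, 6  # illustrative sizes

def layer_cache() -> tuple:
    # (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
    return tuple(torch.zeros(batch, heads, past_len, head_dim) for _ in range(4))

past_key_values = tuple(layer_cache() for _ in range(num_layers))
assert len(past_key_values) == num_layers
assert past_key_values[0][0].shape == (1, 8, 5, 64)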
@@ -1082,8 +1078,7 @@ class MarianDecoder(MarianPreTrainedModel):


 @add_start_docstrings(
-    "The bare Marian Model outputting raw hidden-states without any specific head on top.",
-    MARIAN_START_DOCSTRING,
+    "The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
 )
 class MarianModel(MarianPreTrainedModel):
     def __init__(self, config: MarianConfig):
@@ -1143,7 +1138,7 @@ class MarianModel(MarianPreTrainedModel):
     def get_decoder(self):
         return self.decoder

-    def resize_decoder_token_embeddings(self, new_num_tokens):
+    def resize_decoder_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
         if self.config.share_encoder_decoder_embeddings:
             raise ValueError(
                 "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
@@ -1171,22 +1166,22 @@
     @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Seq2SeqModelOutput:
         r"""
         Returns:

@@ -1279,10 +1274,7 @@ class MarianMTModel(MarianPreTrainedModel):
         r"embed_positions",
     ]

-    _keys_to_ignore_on_save = [
-        "model.encoder.embed_positions.weight",
-        "model.decoder.embed_positions.weight",
-    ]
+    _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]

     def __init__(self, config: MarianConfig):
         super().__init__(config)
@@ -1309,7 +1301,7 @@ class MarianMTModel(MarianPreTrainedModel):
         self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings

-    def _resize_token_embeddings(self, new_num_tokens):
+    def _resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
         old_embeddings = self.get_input_embeddings()
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
         self.set_input_embeddings(new_embeddings)
@@ -1370,7 +1362,7 @@ class MarianMTModel(MarianPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: nn.Embedding):
         self.lm_head = new_embeddings

     def tie_weights(self):
@@ -1400,23 +1392,23 @@ class MarianMTModel(MarianPreTrainedModel):
     @add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Seq2SeqLMOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
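With the return type pinned to `Seq2SeqLMOutput`, callers can rely on its named fields. A short usage sketch, assuming a standard transformers install; the checkpoint name is only an illustrative choice, and the `labels` here exist purely to show that a loss is returned:

from transformers import MarianMTModel, MarianTokenizer

name = "Helsinki-NLP/opus-mt-en-de"  # illustrative checkpoint
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)

batch = tokenizer(["Hello, world!"], return_tensors="pt")
out = model(**batch, labels=batch["input_ids"])  # Seq2SeqLMOutput
print(out.loss.item(), out.logits.shape)

translated = model.generate(**batch)
print(tokenizer.batch_decode(translated, skip_special_tokens=True))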
@@ -1479,16 +1471,16 @@ class MarianMTModel(MarianPreTrainedModel):

     def prepare_inputs_for_generation(
         self,
-        decoder_input_ids,
-        past=None,
-        attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs
-    ):
+        decoder_input_ids: torch.LongTensor,
+        past: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+        **kwargs,
+    ) -> Dict:
         # cut decoder_input_ids if past is used
         if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
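The `# cut decoder_input_ids if past is used` branch above keeps only the most recent token once a cache exists, since earlier positions are already represented in `past`. A tiny standalone illustration of that trimming, with hypothetical values:

import torch

decoder_input_ids = torch.tensor([[2, 17, 923, 4]])  # hypothetical token ids
past = ("non-empty cache stand-in",)                 # anything non-None triggers the cut

if past is not None:
    decoder_input_ids = decoder_input_ids[:, -1:]

assert decoder_input_ids.tolist() == [[4]]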
The same annotations, propagated by fix-copies to the Pegasus copy of the positional embedding:

@@ -109,12 +109,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""

-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
         self.weight = self._init_weight(self.weight)

     @staticmethod
-    def _init_weight(out: nn.Parameter):
+    def _init_weight(out: nn.Parameter) -> nn.Parameter:
         """
         Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
         the 2nd half of the vector. [dim // 2:]
@@ -131,7 +131,7 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
         return out

     @torch.no_grad()
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids_shape[:2]
         positions = torch.arange(
And to the RoFormer copy:

@@ -73,12 +73,12 @@ ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""

-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
         self.weight = self._init_weight(self.weight)

     @staticmethod
-    def _init_weight(out: nn.Parameter):
+    def _init_weight(out: nn.Parameter) -> nn.Parameter:
         """
         Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
         the 2nd half of the vector. [dim // 2:]
@@ -95,7 +95,7 @@ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
         return out

     @torch.no_grad()
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids_shape[:2]
         positions = torch.arange(
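A quick way to inspect the new annotations after installing a release that contains this commit is `typing.get_type_hints`; the module path below assumes the standard transformers source layout:

import typing

from transformers.models.marian.modeling_marian import MarianEncoder

hints = typing.get_type_hints(MarianEncoder.forward)
print(hints["return"])  # typing.Union[typing.Tuple[torch.Tensor], ...BaseModelOutput]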