Add Type Hints to modeling_utils.py Closes #3911 (#3948)

* Add Type Hints to modeling_utils.py Closes #3911

Add Type Hints to methods in `modeling_utils.py`

Note: The coverage isn't 100%. Mostly skipped internal methods.

* Reformat according to `black` and `isort`

* Use typing.Iterable instead of Sequence

* Parameterize Iterable by its generic type

* Use typing.Optional when None is the default value

* Adhere to style guideline

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Author: Bijay Gurung (committed by GitHub)
Date:   2020-05-23 04:55:22 +05:45
Parent: 996f393a86
Commit: e19b978151

@@ -17,7 +17,7 @@
import inspect
import logging
import os
from typing import Callable, List, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple
import torch
from torch import Tensor, device, dtype, nn
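
As the commit-message bullets above note, the convention throughout is `Optional[...]` wherever `None` is the default value, and `Iterable` parameterized by its element type. A minimal, standalone sketch of that convention (the function here is hypothetical, not part of the diff):

```python
from typing import Iterable, List, Optional


def take_ids(token_ids: Iterable[int], limit: Optional[int] = None) -> List[int]:
    # Optional[int] signals that None is an accepted default;
    # Iterable[int] names the element type instead of a bare Iterable.
    ids = list(token_ids)
    return ids if limit is None else ids[:limit]
```
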
@@ -164,7 +164,7 @@ class ModuleUtilsMixin:
return encoder_extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
Arguments:
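
For context on the annotated return value, a simplified sketch of what the non-causal path of `get_extended_attention_mask` does with a 2D padding mask of shape `(batch, seq_len)` (the real method also builds causal masks and handles dtype/device placement):

```python
import torch
from torch import Tensor


def extend_attention_mask(attention_mask: Tensor) -> Tensor:
    # (batch, seq_len) -> (batch, 1, 1, seq_len) so it broadcasts over heads and query positions.
    extended = attention_mask[:, None, None, :].to(dtype=torch.float)
    # Attended positions become 0.0; masked positions become a large negative
    # value that is added to the attention scores before the softmax.
    return (1.0 - extended) * -10000.0
```
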
@@ -208,7 +208,7 @@ class ModuleUtilsMixin:
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
"""
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
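
As the comments in this hunk say, `1.0` in `head_mask` means the head is kept. A hedged sketch of the broadcasting the method performs for a 1D per-head mask (the real implementation also accepts per-layer masks and a chunked-attention variant):

```python
import torch
from torch import Tensor


def expand_head_mask(head_mask: Tensor, num_hidden_layers: int) -> Tensor:
    # (num_heads,) -> (num_layers, 1, num_heads, 1, 1), broadcastable against
    # attention probabilities of shape (batch, num_heads, seq_len, seq_len).
    mask = head_mask[None, None, :, None, None]
    return mask.expand(num_hidden_layers, -1, -1, -1, -1).to(dtype=torch.float)
```
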
@@ -302,7 +302,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
raise NotImplementedError
def set_input_embeddings(self, value):
def set_input_embeddings(self, value: nn.Module):
"""
Set model's input embeddings
@@ -354,7 +354,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens=None):
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
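
A typical call site for the newly annotated signature, with illustrative model and tokenizer names (not taken from the diff):

```python
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# Grow the input embedding matrix after adding tokens to the tokenizer;
# the method returns the (possibly new) input embeddings module.
tokenizer.add_tokens(["<custom_token>"])
embeddings = model.resize_token_embeddings(new_num_tokens=len(tokenizer))
```
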
@@ -387,18 +387,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
self.set_input_embeddings(new_embeddings)
return self.get_input_embeddings()
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
def _get_resized_embeddings(
self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
) -> torch.nn.Embedding:
""" Build a resized Embedding Module from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
Args:
old_embeddings: ``torch.nn.Embedding``
Old embeddings to be resized.
new_num_tokens: (`optional`) int
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
If not provided or None: return the provided token Embedding Module.
Return: ``torch.nn.Embeddings``
Return: ``torch.nn.Embedding``
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
"""
if new_num_tokens is None:
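
A minimal sketch of the copy-and-pad behaviour the docstring above describes (freshly initialized rows appended when growing, trailing rows dropped when shrinking); this is a simplification, not the library's `_get_resized_embeddings`:

```python
import torch.nn as nn


def resize_embedding(old: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    old_num_tokens, embedding_dim = old.weight.shape
    new = nn.Embedding(new_num_tokens, embedding_dim)
    # Copy the overlapping rows; rows beyond the old size keep their fresh
    # random init, and rows beyond the new size are simply dropped.
    n = min(old_num_tokens, new_num_tokens)
    new.weight.data[:n, :] = old.weight.data[:n, :]
    return new
```
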
@@ -433,7 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Tie weights if needed
self.tie_weights()
def prune_heads(self, heads_to_prune):
def prune_heads(self, heads_to_prune: Dict):
""" Prunes heads of the base model.
Arguments:
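
The `heads_to_prune` dict maps layer indices to lists of head indices to prune in that layer, e.g. (illustrative model name):

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
# Prune heads 0 and 2 in layer 1, and head 3 in layer 5 of the base model.
model.prune_heads({1: [0, 2], 5: [3]})
```
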
@@ -801,28 +805,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
@torch.no_grad()
def generate(
self,
input_ids=None,
max_length=None,
min_length=None,
do_sample=None,
early_stopping=None,
num_beams=None,
temperature=None,
top_k=None,
top_p=None,
repetition_penalty=None,
bad_words_ids=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
input_ids: Optional[torch.LongTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_specific_kwargs
):
) -> torch.LongTensor:
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
Adapted in part from `Facebook's XLM beam search code`_.
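
A call exercising several of the newly annotated parameters; the model, prompt, and sampling settings below are illustrative, not from the diff:

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Type hints make the API", return_tensors="pt")  # torch.LongTensor
output_ids = model.generate(  # returns a torch.LongTensor of generated ids
    input_ids=input_ids,
    max_length=30,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    num_return_sequences=2,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
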
@@ -1606,7 +1610,7 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
return banned_tokens
def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []
def _tokens_match(prev_tokens, tokens):
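
The `_tokens_match` helper checks whether all but the last token of a banned sequence already end the generated ids; a simplified, hypothetical sketch of that prefix-matching idea (it treats `bad_words_ids` as a list of token-id lists):

```python
from typing import List


def banned_tokens_for(prev_ids: List[int], bad_words_ids: List[List[int]]) -> List[int]:
    banned = []
    for bad_word in bad_words_ids:
        prefix, last = bad_word[:-1], bad_word[-1]
        # A one-token bad word is always banned; a longer one is banned only
        # when everything but its last token already ends the generated sequence.
        if len(prefix) == 0 or prev_ids[-len(prefix):] == prefix:
            banned.append(last)
    return banned
```
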
@@ -1642,7 +1646,13 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
return banned_tokens
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
def top_k_top_p_filtering(
logits: Tensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> Tensor:
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
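
A self-contained sketch of the filtering the docstring describes, assuming `logits` of shape `(batch_size, vocab_size)`; it follows the standard top-k / nucleus recipe rather than reproducing the library function line for line:

```python
import torch
from torch import Tensor


def filter_logits(logits: Tensor, top_k: int = 0, top_p: float = 1.0,
                  filter_value: float = -float("Inf")) -> Tensor:
    logits = logits.clone()
    if top_k > 0:
        # Mask every token whose logit is below the k-th largest logit.
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_value] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Drop tokens once cumulative probability exceeds top_p, shifted right
        # so the first token crossing the threshold is still kept.
        sorted_remove = cumulative_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits[remove] = filter_value
    return logits
```
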