Mirror of https://github.com/huggingface/transformers.git
Synced 2025-07-31 02:02:21 +06:00
Add Type Hints to modeling_utils.py

Closes #3911

* Add Type Hints to methods in `modeling_utils.py`. Note: coverage isn't 100%; mostly internal methods were skipped.
* Reformat according to `black` and `isort`
* Use typing.Iterable instead of Sequence
* Parameterize Iterable by its generic type
* Use typing.Optional when None is the default value
* Adhere to style guideline
* Update src/transformers/modeling_utils.py
* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent 996f393a86
commit e19b978151
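The conventions listed in the commit message (use `typing.Optional` wherever `None` is the default, and parameterize `Iterable` by its element type) can be summed up with a minimal, hypothetical sketch; neither function below appears in the diff, they only illustrate the style:

from typing import Iterable, Optional

from torch import Tensor


# Hypothetical signatures, shown only to illustrate the typing conventions above.
def lookup_embeddings(input_ids: Tensor, pad_token_id: Optional[int] = None) -> Tensor:
    # pad_token_id defaults to None, so it is annotated Optional[int].
    ...


def flatten_bad_words(bad_words_ids: Optional[Iterable[int]] = None) -> Iterable[int]:
    # Iterable is parameterized by its element type instead of being left bare.
    return list(bad_words_ids or [])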
@@ -17,7 +17,7 @@
 import inspect
 import logging
 import os
-from typing import Callable, List, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import Tensor, device, dtype, nn
@@ -164,7 +164,7 @@ class ModuleUtilsMixin:
 
         return encoder_extended_attention_mask
 
-    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
         """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
 
         Arguments:
@@ -208,7 +208,7 @@ class ModuleUtilsMixin:
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask
 
-    def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
+    def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
         """
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
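For orientation, a tiny self-contained sketch of the additive-mask trick visible in this hunk: a (batch_size, seq_len) padding mask is broadcast to (batch_size, 1, 1, seq_len) and masked positions become a large negative bias. It is written independently, not copied from the file:

import torch

# Padding mask of shape (batch_size, seq_len): 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Reshape so it broadcasts over attention heads and query positions, then
# turn masked positions into a large negative additive bias, as in the hunk.
extended = attention_mask[:, None, None, :].float()
extended = (1.0 - extended) * -10000.0
print(extended.shape)  # torch.Size([1, 1, 1, 4])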
@@ -302,7 +302,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         else:
             raise NotImplementedError
 
-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Module):
         """
         Set model's input embeddings
 
@@ -354,7 +354,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings
 
-    def resize_token_embeddings(self, new_num_tokens=None):
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
         Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
 
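A usage sketch for the now-annotated `resize_token_embeddings`; the checkpoint name and added tokens are placeholders, but the call pattern follows the signature above:

from transformers import BertModel, BertTokenizer

# Placeholder checkpoint; any pretrained model/tokenizer pair works the same way.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])

# new_num_tokens is Optional[int]; passing the updated vocab size grows the
# input embedding matrix and re-ties weights where the model supports it.
embeddings = model.resize_token_embeddings(len(tokenizer))
print(embeddings.num_embeddings)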
@@ -387,18 +387,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         self.set_input_embeddings(new_embeddings)
         return self.get_input_embeddings()
 
-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+    def _get_resized_embeddings(
+        self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
+    ) -> torch.nn.Embedding:
         """ Build a resized Embedding Module from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
             Reducing the size will remove vectors from the end
 
         Args:
+            old_embeddings: ``torch.nn.Embedding``
+                Old embeddings to be resized.
             new_num_tokens: (`optional`) int
                 New number of tokens in the embedding matrix.
                 Increasing the size will add newly initialized vectors at the end
                 Reducing the size will remove vectors from the end
                 If not provided or None: return the provided token Embedding Module.
-        Return: ``torch.nn.Embeddings``
+        Return: ``torch.nn.Embedding``
             Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
         """
         if new_num_tokens is None:
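The resizing that `_get_resized_embeddings` documents amounts to copying the overlapping rows of the old embedding matrix into a freshly initialized one; a rough, independent sketch of that idea (not the file's implementation):

import torch
from torch import nn


def resized_embedding_sketch(old: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # Fresh embedding with the requested vocabulary size and the same hidden size.
    new = nn.Embedding(new_num_tokens, old.embedding_dim)
    # Copy the rows both matrices share; any extra rows keep their fresh init.
    num_to_copy = min(old.num_embeddings, new_num_tokens)
    with torch.no_grad():
        new.weight[:num_to_copy, :] = old.weight[:num_to_copy, :]
    return new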
@@ -433,7 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # Tie weights if needed
         self.tie_weights()
 
-    def prune_heads(self, heads_to_prune):
+    def prune_heads(self, heads_to_prune: Dict):
         """ Prunes heads of the base model.
 
            Arguments:
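`heads_to_prune` is a dict mapping a layer index to the list of head indices to remove in that layer, which is what the new `Dict` annotation refers to. A small usage sketch (the checkpoint is a placeholder):

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint

# Prune heads 0 and 2 in layer 0, and head 1 in layer 5.
model.prune_heads({0: [0, 2], 5: [1]})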
@@ -801,28 +805,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
     @torch.no_grad()
     def generate(
         self,
-        input_ids=None,
-        max_length=None,
-        min_length=None,
-        do_sample=None,
-        early_stopping=None,
-        num_beams=None,
-        temperature=None,
-        top_k=None,
-        top_p=None,
-        repetition_penalty=None,
-        bad_words_ids=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        num_return_sequences=None,
-        attention_mask=None,
-        decoder_start_token_id=None,
-        use_cache=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        num_return_sequences: Optional[int] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
         **model_specific_kwargs
-    ):
+    ) -> torch.LongTensor:
         r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
 
         Adapted in part from `Facebook's XLM beam search code`_.
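A usage sketch against the newly typed `generate` signature; the causal LM checkpoint and prompt are placeholders, and each keyword argument corresponds to one of the annotated parameters above:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Type hints make APIs", return_tensors="pt")

# max_length/top_k: Optional[int]; top_p/temperature: Optional[float];
# do_sample: Optional[bool]; the return value is a torch.LongTensor of token ids.
output_ids = model.generate(
    input_ids,
    max_length=40,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.8,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))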
@@ -1606,7 +1610,7 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
     return banned_tokens
 
 
-def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
+def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
     banned_tokens = []
 
     def _tokens_match(prev_tokens, tokens):
@@ -1642,7 +1646,13 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
     return banned_tokens
 
 
-def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
+def top_k_top_p_filtering(
+    logits: Tensor,
+    top_k: int = 0,
+    top_p: float = 1.0,
+    filter_value: float = -float("Inf"),
+    min_tokens_to_keep: int = 1,
+) -> Tensor:
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
             logits: logits distribution shape (batch size, vocabulary size)
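Finally, a short sketch of how the standalone `top_k_top_p_filtering` helper is typically used: filter a batch of next-token logits, then sample from the renormalized distribution. The random logits are placeholders, and the import path simply reflects the module being edited here:

import torch
import torch.nn.functional as F

from transformers.modeling_utils import top_k_top_p_filtering

# Placeholder next-token logits of shape (batch_size, vocab_size).
logits = torch.randn(2, 50257)

# Keep the top 50 tokens and the smallest nucleus holding >= 0.95 probability
# mass; everything else is set to filter_value (-inf) and vanishes after softmax.
filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.95)
next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
print(next_tokens.shape)  # torch.Size([2, 1])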