Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
* Add Type Hints to modeling_utils.py

  Closes #3911. Add type hints to methods in `modeling_utils.py`. Note: the coverage isn't 100%; internal methods were mostly skipped.

* Reformat according to `black` and `isort`
* Use typing.Iterable instead of Sequence
* Parameterize Iterable by its generic type
* Use typing.Optional when None is the default value
* Adhere to style guideline
* Update src/transformers/modeling_utils.py
* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
996f393a86
commit
e19b978151
@@ -17,7 +17,7 @@
 import inspect
 import logging
 import os
-from typing import Callable, List, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Tuple

 import torch
 from torch import Tensor, device, dtype, nn
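For context, a minimal sketch (not part of this diff, function name is hypothetical) of the conventions the new imports support: `Optional[X]` marks parameters whose default is `None`, and `Iterable[X]` describes arguments that only need to be iterated.

```python
from typing import Dict, Iterable, Optional

# Hypothetical signature illustrating the Optional/Iterable conventions adopted here.
def truncate(ids: Iterable[int], max_length: Optional[int] = None) -> Dict[str, int]:
    ids = list(ids)
    if max_length is not None:
        ids = ids[:max_length]
    return {"kept": len(ids)}
```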
@@ -164,7 +164,7 @@ class ModuleUtilsMixin:

        return encoder_extended_attention_mask

-    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
        """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.

        Arguments:
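A rough sketch of what the newly typed signature describes (an assumption, not the library's implementation): a `[batch, seq]` padding mask is broadcast to `[batch, 1, 1, seq]` and turned into an additive mask, matching the `(1.0 - extended_attention_mask) * -10000.0` context line in the following hunk.

```python
import torch

def extend_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # [batch, seq] -> [batch, 1, 1, seq], broadcastable over heads and query positions
    extended = attention_mask[:, None, None, :].to(torch.float32)
    # 0.0 where tokens are attended, -10000.0 where they are masked out
    return (1.0 - extended) * -10000.0

mask = torch.tensor([[1, 1, 1, 0]])       # last position is padding
print(extend_attention_mask(mask).shape)  # torch.Size([1, 1, 1, 4])
```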
@@ -208,7 +208,7 @@ class ModuleUtilsMixin:
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

-    def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
+    def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
        """
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
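A minimal sketch of the behaviour the new `Tensor -> Tensor` annotation documents (an assumption, not the exact library code): a per-head mask is expanded so that every transformer layer receives a broadcastable mask.

```python
import torch

def expand_head_mask(head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
    if head_mask.dim() == 1:    # [num_heads] -> same mask for every layer
        head_mask = head_mask[None, None, :, None, None]
        head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
    elif head_mask.dim() == 2:  # [num_layers, num_heads] -> per-layer mask
        head_mask = head_mask[:, None, :, None, None]
    return head_mask            # broadcastable to [num_layers, batch, heads, seq, seq]

print(expand_head_mask(torch.ones(12), num_hidden_layers=6).shape)  # torch.Size([6, 1, 12, 1, 1])
```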
@@ -302,7 +302,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        else:
            raise NotImplementedError

-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Module):
        """
        Set model's input embeddings

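Usage sketch (the model name is an assumption, not taken from this diff): the `value: nn.Module` annotation reflects that any embedding-like module can be swapped in.

```python
import torch.nn as nn
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
old = model.get_input_embeddings()  # an nn.Embedding
model.set_input_embeddings(nn.Embedding(old.num_embeddings, old.embedding_dim))
```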
@@ -354,7 +354,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

-    def resize_token_embeddings(self, new_num_tokens=None):
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

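Usage sketch (model and tokenizer names are assumptions): after new tokens are added to a tokenizer, `resize_token_embeddings` grows the embedding matrix to match; passing `None`, the new `Optional[int]` default, leaves the matrix unchanged.

```python
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

tokenizer.add_tokens(["<new_token>"])
model.resize_token_embeddings(new_num_tokens=len(tokenizer))  # Optional[int]
```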
@@ -387,18 +387,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        self.set_input_embeddings(new_embeddings)
        return self.get_input_embeddings()

-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+    def _get_resized_embeddings(
+        self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
+    ) -> torch.nn.Embedding:
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
+            old_embeddings: ``torch.nn.Embedding``
+                Old embeddings to be resized.
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
-        Return: ``torch.nn.Embeddings``
+        Return: ``torch.nn.Embedding``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        if new_num_tokens is None:
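A simplified sketch of what a resize helper of this shape does (an illustration under assumptions, not the library's `_get_resized_embeddings`): build a new `nn.Embedding` and copy over the overlapping rows.

```python
import torch
from torch import nn

def resize_embedding(old: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    old_num_tokens, dim = old.weight.size()
    new = nn.Embedding(new_num_tokens, dim)
    # copy the rows present in both matrices; extra rows keep their fresh
    # initialization, shrinking simply drops rows from the end
    n = min(old_num_tokens, new_num_tokens)
    new.weight.data[:n, :] = old.weight.data[:n, :]
    return new

print(resize_embedding(nn.Embedding(30522, 768), 30524).weight.shape)  # torch.Size([30524, 768])
```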
@@ -433,7 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        # Tie weights if needed
        self.tie_weights()

-    def prune_heads(self, heads_to_prune):
+    def prune_heads(self, heads_to_prune: Dict):
        """ Prunes heads of the base model.

        Arguments:
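Usage sketch (model name is an assumption): the `Dict` annotation covers a mapping from layer index to the list of attention head indices to prune.

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
# prune heads 0 and 2 in layer 0, and head 1 in layer 5
model.prune_heads({0: [0, 2], 5: [1]})
```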
@@ -801,28 +805,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
    @torch.no_grad()
    def generate(
        self,
-        input_ids=None,
-        max_length=None,
-        min_length=None,
-        do_sample=None,
-        early_stopping=None,
-        num_beams=None,
-        temperature=None,
-        top_k=None,
-        top_p=None,
-        repetition_penalty=None,
-        bad_words_ids=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        num_return_sequences=None,
-        attention_mask=None,
-        decoder_start_token_id=None,
-        use_cache=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        num_return_sequences: Optional[int] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
        **model_specific_kwargs
-    ):
+    ) -> torch.LongTensor:
        r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

        Adapted in part from `Facebook's XLM beam search code`_.
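Usage sketch (model choice and decoding settings are assumptions, not taken from this diff) showing the now-annotated keyword arguments of `generate`:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids: torch.LongTensor = tokenizer.encode("Type hints make the API", return_tensors="pt")
output_ids = model.generate(
    input_ids=input_ids,     # Optional[torch.LongTensor]
    max_length=30,           # Optional[int]
    do_sample=True,          # Optional[bool]
    top_k=50,                # Optional[int]
    top_p=0.95,              # Optional[float]
    num_return_sequences=1,  # Optional[int]
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```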
@@ -1606,7 +1610,7 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
    return banned_tokens


-def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
+def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
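A simplified sketch of the idea behind this helper (an assumption about the logic, not the function as shipped): a bad word's last token is banned whenever its preceding tokens were just generated.

```python
from typing import Iterable, List

def banned_tokens_for(prev_tokens: List[int], bad_words_ids: Iterable[List[int]]) -> List[int]:
    banned: List[int] = []
    for bad_word in bad_words_ids:
        prefix, last = bad_word[:-1], bad_word[-1]
        # single-token bad words are always banned; longer ones only when their
        # prefix matches the most recently generated tokens
        if len(prefix) == 0 or prev_tokens[-len(prefix):] == prefix:
            banned.append(last)
    return banned

print(banned_tokens_for([5, 8, 13], [[13, 7], [99]]))  # [7, 99]
```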
@@ -1642,7 +1646,13 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
    return banned_tokens


-def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
+def top_k_top_p_filtering(
+    logits: Tensor,
+    top_k: int = 0,
+    top_p: float = 1.0,
+    filter_value: float = -float("Inf"),
+    min_tokens_to_keep: int = 1,
+) -> Tensor:
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
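A minimal sketch of top-k / nucleus filtering of the kind this signature describes (an illustration under assumptions, not the library function itself): logits outside the top-k set, or outside the smallest set whose cumulative probability exceeds top_p, are pushed to the filter value before sampling.

```python
import torch
import torch.nn.functional as F

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0,
                  filter_value: float = -float("Inf")) -> torch.Tensor:
    # logits: [batch, vocab]
    if top_k > 0:
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # always keep the highest-probability token
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), filter_value)
    return logits

probs = F.softmax(filter_logits(torch.randn(1, 50), top_k=10, top_p=0.9), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```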