Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
Adding type hints for Distilbert (#16090)
* Distillbert type - squash
* Update src/transformers/models/distilbert/modeling_distilbert.py
  Undo cleanup
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* Update src/transformers/models/distilbert/modeling_distilbert.py
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* Update src/transformers/models/distilbert/modeling_distilbert.py
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* Update src/transformers/models/distilbert/modeling_distilbert.py
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* Remove type
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent 0b8b06185d
commit 5bdf3313ef
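The change is mechanical and consistent across the file: config arguments are annotated as PretrainedConfig, optional tensors become Optional[torch.Tensor] = None, and each forward declares Union of the matching ModelOutput subclass and a plain tuple as its return type. A minimal sketch of that convention (an illustrative skeleton, not a line taken from the diff):

from typing import Optional, Tuple, Union

import torch

from transformers.modeling_outputs import BaseModelOutput


class SketchModule:  # hypothetical stand-in for one of the DistilBert modules
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
        ...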
@@ -19,6 +19,7 @@


 import math
+from typing import Dict, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -26,6 +27,8 @@ from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from transformers.configuration_utils import PretrainedConfig
+
 from ...activations import get_activation
 from ...deepspeed import is_deepspeed_zero3_enabled
 from ...file_utils import (
@@ -72,7 +75,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 # UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #


-def create_sinusoidal_embeddings(n_pos, dim, out):
+def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     if is_deepspeed_zero3_enabled():
         import deepspeed

@@ -83,7 +86,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
         _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)


-def _create_sinusoidal_embeddings(n_pos, dim, out):
+def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
     out.requires_grad = False
     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
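For context on the out parameter now typed as torch.Tensor: the tensor these helpers fill in place is the weight of the position-embedding table, which Embeddings.__init__ passes in when config.sinusoidal_pos_embds is set. A hedged standalone sketch (the sizes are made up; the import path is the one used at this commit):

import torch
from torch import nn

from transformers.models.distilbert.modeling_distilbert import create_sinusoidal_embeddings

n_pos, dim = 512, 768  # hypothetical max_position_embeddings and hidden size
position_embeddings = nn.Embedding(n_pos, dim)
# Fills the weight in place with fixed sin/cos encodings and freezes it.
create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=position_embeddings.weight)
assert position_embeddings.weight.requires_grad is False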
@@ -92,7 +95,7 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):


 class Embeddings(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
@@ -108,7 +111,7 @@ class Embeddings(nn.Module):
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )

-    def forward(self, input_ids):
+    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         """
         Parameters:
             input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
@@ -137,7 +140,7 @@ class Embeddings(nn.Module):


 class MultiHeadSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()

         self.n_heads = config.n_heads
@@ -151,9 +154,9 @@ class MultiHeadSelfAttention(nn.Module):
         self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
         self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

-        self.pruned_heads = set()
+        self.pruned_heads: Set[int] = set()

-    def prune_heads(self, heads):
+    def prune_heads(self, heads: List[int]):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
@@ -168,7 +171,15 @@ class MultiHeadSelfAttention(nn.Module):
         self.dim = attention_head_size * self.n_heads
         self.pruned_heads = self.pruned_heads.union(heads)

-    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
         """
         Parameters:
             query: torch.tensor(bs, seq_length, dim)
@@ -189,11 +200,11 @@ class MultiHeadSelfAttention(nn.Module):

         mask_reshp = (bs, 1, 1, k_length)

-        def shape(x):
+        def shape(x: torch.Tensor) -> torch.Tensor:
             """separate heads"""
             return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

-        def unshape(x):
+        def unshape(x: torch.Tensor) -> torch.Tensor:
             """group heads"""
             return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

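The typed shape and unshape helpers are exact inverses: one splits the hidden dimension into (n_heads, dim_per_head), the other merges it back. A small standalone check with made-up sizes, mirroring the two expressions above:

import torch

bs, seq_length, n_heads, dim_per_head = 2, 5, 12, 64  # hypothetical sizes
x = torch.randn(bs, seq_length, n_heads * dim_per_head)

# shape(): (bs, seq_length, dim) -> (bs, n_heads, seq_length, dim_per_head)
split = x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)
# unshape(): back to (bs, seq_length, dim)
merged = split.transpose(1, 2).contiguous().view(bs, -1, n_heads * dim_per_head)

assert torch.equal(merged, x)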
@@ -224,7 +235,7 @@ class MultiHeadSelfAttention(nn.Module):


 class FFN(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.dropout = nn.Dropout(p=config.dropout)
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -233,10 +244,10 @@ class FFN(nn.Module):
         self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
         self.activation = get_activation(config.activation)

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

-    def ff_chunk(self, input):
+    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
         x = self.lin1(input)
         x = self.activation(x)
         x = self.lin2(x)
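FFN.forward hands ff_chunk to apply_chunking_to_forward, which runs the feed-forward block on slices of the sequence dimension and concatenates the results, so the typed input and output shapes are unchanged by chunking. A pure-PyTorch illustration of that behaviour (not the library helper itself; sizes and the GELU/linear stack are made up):

import torch
from torch import nn

dim, hidden_dim, chunk_size = 768, 3072, 2  # hypothetical sizes
ffn = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim))

x = torch.randn(2, 6, dim)  # (bs, seq_length, dim)
# Apply the block chunk by chunk along the sequence axis, then stitch the pieces back together.
chunked = torch.cat([ffn(chunk) for chunk in x.split(chunk_size, dim=1)], dim=1)

assert chunked.shape == x.shape
assert torch.allclose(chunked, ffn(x), atol=1e-6)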
@@ -245,7 +256,7 @@ class FFN(nn.Module):


 class TransformerBlock(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()

         assert config.dim % config.n_heads == 0
@@ -256,7 +267,13 @@ class TransformerBlock(nn.Module):
         self.ffn = FFN(config)
         self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

-    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
         """
         Parameters:
             x: torch.tensor(bs, seq_length, dim)
@@ -284,7 +301,7 @@ class TransformerBlock(nn.Module):

         # Feed Forward Network
         ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
-        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)
+        ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

         output = (ffn_output,)
         if output_attentions:
@@ -293,14 +310,20 @@ class TransformerBlock(nn.Module):


 class Transformer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.n_layers = config.n_layers
         self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])

     def forward(
-        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None
-    ):  # docstyle-ignore
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:  # docstyle-ignore
         """
         Parameters:
             x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
@@ -357,7 +380,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = None
     base_model_prefix = "distilbert"

-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
@ -432,7 +455,7 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
|
||||
DISTILBERT_START_DOCSTRING,
|
||||
)
|
||||
class DistilBertModel(DistilBertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.embeddings = Embeddings(config) # Embeddings
|
||||
@@ -489,13 +512,13 @@ class DistilBertModel(DistilBertPreTrainedModel):
         # move position_embeddings to correct device
         self.embeddings.position_embeddings.to(self.device)

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings.word_embeddings

-    def set_input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings: nn.Embedding):
         self.embeddings.word_embeddings = new_embeddings

-    def _prune_heads(self, heads_to_prune):
+    def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]):
         """
         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
         class PreTrainedModel
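For context on the heads_to_prune annotation: users normally go through the public PreTrainedModel.prune_heads entry point, which forwards a {layer_index: [head indices]} mapping to this method (in practice the values are flat lists of head indices). A hedged usage sketch:

from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")
# Drop heads 0 and 1 in layer 0, and head 2 in layer 5.
model.prune_heads({0: [0, 1], 5: [2]})
print(model.transformer.layer[0].attention.n_heads)  # 10 of the original 12 remain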
@@ -512,14 +535,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
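The new return annotation reflects the usual ModelOutput-or-tuple contract: a BaseModelOutput when return_dict resolves to True (the library default), a plain tuple otherwise. A short usage sketch:

import torch
from transformers import AutoTokenizer, DistilBertModel
from transformers.modeling_outputs import BaseModelOutput

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")
inputs = tokenizer("Hello world", return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)                     # BaseModelOutput
    tup = model(**inputs, return_dict=False)  # Tuple[torch.Tensor, ...]

assert isinstance(out, BaseModelOutput) and isinstance(tup, tuple)
print(out.last_hidden_state.shape)  # (1, seq_length, 768)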
@@ -560,7 +583,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForMaskedLM(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)

         self.activation = get_activation(config.activation)
@@ -595,10 +618,10 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         """
         self.distilbert.resize_position_embeddings(new_num_position_embeddings)

-    def get_output_embeddings(self):
+    def get_output_embeddings(self) -> nn.Module:
         return self.vocab_projector

-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: nn.Module):
         self.vocab_projector = new_embeddings

     @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
@@ -610,15 +633,15 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
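A short sketch of the labels argument typed above as Optional[torch.LongTensor]: positions set to -100 are ignored by the loss, while the masked position carries the target token id ("paris" below is just an example target):

import torch
from transformers import AutoTokenizer, DistilBertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
labels = torch.full_like(inputs["input_ids"], -100)  # -100 = ignored by the loss
mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
labels[mask_positions] = tokenizer.convert_tokens_to_ids("paris")

outputs = model(**inputs, labels=labels)  # MaskedLMOutput with .loss and .logits
print(float(outputs.loss))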
@@ -666,7 +689,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
@@ -708,15 +731,15 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
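For the sequence-classification head, the typed labels are class indices of shape (batch_size,); a hedged sketch (num_labels=2 is an arbitrary choice, and the freshly initialized head gives an untrained loss):

import torch
from transformers import AutoTokenizer, DistilBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

inputs = tokenizer("A delightfully sharp little film.", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([1]))  # SequenceClassifierOutput

print(outputs.logits.shape)  # (1, 2)
print(float(outputs.loss))   # cross-entropy against label 1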
@@ -784,7 +807,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)

         self.distilbert = DistilBertModel(config)
@@ -824,16 +847,16 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor, ...]]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
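The start_positions and end_positions tensors typed above are (batch_size,) token indices of the answer span; values outside the sequence are clamped and ignored by the loss. A hedged sketch with a made-up span index:

import torch
from transformers import AutoTokenizer, DistilBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Who wrote it?", "It was written by Jane.", return_tensors="pt")
outputs = model(
    **inputs,
    start_positions=torch.tensor([10]),  # hypothetical index of the answer token
    end_positions=torch.tensor([10]),
)
print(float(outputs.loss), outputs.start_logits.shape)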
@@ -901,7 +924,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForTokenClassification(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.num_labels = config.num_labels

@@ -941,15 +964,15 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -996,7 +1019,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)

         self.distilbert = DistilBertModel(config)
@@ -1033,15 +1056,15 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
     @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,