Adding type hints for Distilbert (#16090)

* DistilBERT type hints - squash

* Update src/transformers/models/distilbert/modeling_distilbert.py

Undo cleanup

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove type

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
John Ryan 2022-03-16 14:54:50 +00:00 committed by GitHub
parent 0b8b06185d
commit 5bdf3313ef


@@ -19,6 +19,7 @@
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
@@ -26,6 +27,8 @@ from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.configuration_utils import PretrainedConfig
from ...activations import get_activation
from ...deepspeed import is_deepspeed_zero3_enabled
from ...file_utils import (
@@ -72,7 +75,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
def create_sinusoidal_embeddings(n_pos, dim, out):
def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
if is_deepspeed_zero3_enabled():
import deepspeed
@@ -83,7 +86,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
_create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
def _create_sinusoidal_embeddings(n_pos, dim, out):
def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
out.requires_grad = False
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
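For reference (not part of this commit), a minimal standalone sketch of the same sinusoidal scheme the helper fills `out` with, returning a fresh tensor instead of writing into an `nn.Embedding` weight:

import numpy as np
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # angle(pos, j) = pos / 10000^(2 * (j // 2) / dim), as in _create_sinusoidal_embeddings
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.empty(n_pos, dim)
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # even dims: sine
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))  # odd dims: cosine
    return out

table = sinusoidal_table(n_pos=512, dim=768)  # same shape as DistilBERT's position-embedding matrix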
@@ -92,7 +95,7 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
class Embeddings(nn.Module):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
@@ -108,7 +111,7 @@ class Embeddings(nn.Module):
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(self, input_ids):
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Parameters:
input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
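A quick shape check for this forward (a standalone sketch, not part of the diff), assuming a randomly initialised model built from the default `DistilBertConfig`:

import torch
from transformers import DistilBertConfig, DistilBertModel

config = DistilBertConfig()
model = DistilBertModel(config)                           # random weights, no download needed
input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (bs, max_seq_length)
embedded = model.embeddings(input_ids)                    # word + position embeddings, LayerNorm, dropout
print(embedded.shape)                                     # torch.Size([2, 16, 768])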
@@ -137,7 +140,7 @@ class Embeddings(nn.Module):
class MultiHeadSelfAttention(nn.Module):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.n_heads = config.n_heads
@@ -151,9 +154,9 @@ class MultiHeadSelfAttention(nn.Module):
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.pruned_heads = set()
self.pruned_heads: Set[int] = set()
def prune_heads(self, heads):
def prune_heads(self, heads: List[int]):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
@@ -168,7 +171,15 @@ class MultiHeadSelfAttention(nn.Module):
self.dim = attention_head_size * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
"""
Parameters:
query: torch.tensor(bs, seq_length, dim)
@@ -189,11 +200,11 @@ class MultiHeadSelfAttention(nn.Module):
mask_reshp = (bs, 1, 1, k_length)
def shape(x):
def shape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
def unshape(x):
def unshape(x: torch.Tensor) -> torch.Tensor:
"""group heads"""
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
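As a standalone sanity check (not part of the diff) of what `shape` and `unshape` do to a `(bs, seq_length, dim)` tensor:

import torch

bs, seq_length, n_heads, dim_per_head = 2, 8, 12, 64
x = torch.randn(bs, seq_length, n_heads * dim_per_head)

heads = x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)    # shape(): (bs, n_heads, seq_length, dim_per_head)
merged = heads.transpose(1, 2).contiguous().view(bs, -1, n_heads * dim_per_head)  # unshape(): back to (bs, seq_length, dim)
assert torch.equal(x, merged)                                    # the round trip is lossless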
@@ -224,7 +235,7 @@ class MultiHeadSelfAttention(nn.Module):
class FFN(nn.Module):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.dropout = nn.Dropout(p=config.dropout)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -233,10 +244,10 @@ class FFN(nn.Module):
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
self.activation = get_activation(config.activation)
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
def ff_chunk(self, input):
def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
x = self.lin1(input)
x = self.activation(x)
x = self.lin2(x)
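`apply_chunking_to_forward` slices the input along `seq_len_dim` into pieces of `chunk_size_feed_forward`, runs `ff_chunk` on each piece, and concatenates the results; with a chunk size of 0 (the default) it simply calls `ff_chunk` on the whole tensor. A rough standalone equivalent (not part of the diff), assuming the chunk size divides the sequence length:

import torch
from torch import nn

def chunked_ffn(ff_chunk, chunk_size: int, seq_len_dim: int, hidden: torch.Tensor) -> torch.Tensor:
    if chunk_size == 0:
        return ff_chunk(hidden)
    chunks = hidden.split(chunk_size, dim=seq_len_dim)            # slice along the sequence axis
    return torch.cat([ff_chunk(chunk) for chunk in chunks], dim=seq_len_dim)

ffn = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
hidden_states = torch.randn(2, 16, 768)
out = chunked_ffn(ffn, chunk_size=4, seq_len_dim=1, hidden=hidden_states)   # (2, 16, 768)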
@@ -245,7 +256,7 @@ class FFN(nn.Module):
class TransformerBlock(nn.Module):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__()
assert config.dim % config.n_heads == 0
@@ -256,7 +267,13 @@ class TransformerBlock(nn.Module):
self.ffn = FFN(config)
self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
"""
Parameters:
x: torch.tensor(bs, seq_length, dim)
@@ -284,7 +301,7 @@ class TransformerBlock(nn.Module):
# Feed Forward Network
ffn_output = self.ffn(sa_output) # (bs, seq_length, dim)
ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
output = (ffn_output,)
if output_attentions:
@@ -293,14 +310,20 @@ class TransformerBlock(nn.Module):
class Transformer(nn.Module):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.n_layers = config.n_layers
self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
def forward(
self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None
): # docstyle-ignore
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: Optional[bool] = None,
) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore
"""
Parameters:
x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
@@ -357,7 +380,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
load_tf_weights = None
base_model_prefix = "distilbert"
def _init_weights(self, module):
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
@@ -432,7 +455,7 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
DISTILBERT_START_DOCSTRING,
)
class DistilBertModel(DistilBertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.embeddings = Embeddings(config) # Embeddings
@@ -489,13 +512,13 @@ class DistilBertModel(DistilBertPreTrainedModel):
# move position_embeddings to correct device
self.embeddings.position_embeddings.to(self.device)
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings.word_embeddings
def set_input_embeddings(self, new_embeddings):
def set_input_embeddings(self, new_embeddings: nn.Embedding):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
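For context (a hedged sketch, not part of the diff), the dict this method receives comes from `PreTrainedModel.prune_heads`; the layer and head indices below are illustrative:

from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")
model.prune_heads({0: [0, 2], 3: [5]})                   # drop heads 0 and 2 in layer 0, head 5 in layer 3
print(model.transformer.layer[0].attention.n_heads)      # 10 instead of 12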
@@ -512,14 +535,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
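A brief usage sketch for the annotated return type (not part of this commit), assuming the public `distilbert-base-uncased` checkpoint:

from transformers import AutoTokenizer, DistilBertModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)                                  # BaseModelOutput by default
print(outputs.last_hidden_state.shape)                     # (1, sequence_length, 768)

last_hidden_state = model(**inputs, return_dict=False)[0]  # plain tuple when return_dict=False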
@@ -560,7 +583,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.activation = get_activation(config.activation)
@@ -595,10 +618,10 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
def get_output_embeddings(self):
def get_output_embeddings(self) -> nn.Module:
return self.vocab_projector
def set_output_embeddings(self, new_embeddings):
def set_output_embeddings(self, new_embeddings: nn.Module):
self.vocab_projector = new_embeddings
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
@@ -610,15 +633,15 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
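A short usage sketch for this signature (not part of this commit); the sentences are illustrative, and non-masked positions are set to -100 so the loss ignores them:

from transformers import AutoTokenizer, DistilBertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
labels[inputs.input_ids != tokenizer.mask_token_id] = -100   # only score the masked position

outputs = model(**inputs, labels=labels)                     # MaskedLMOutput with .loss and .logits
print(outputs.loss, outputs.logits.shape)                    # scalar loss, (1, sequence_length, vocab_size)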
@@ -666,7 +689,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
DISTILBERT_START_DOCSTRING,
)
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
@@ -708,15 +731,15 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
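A usage sketch for this signature (not part of this commit); the base checkpoint has no classification head, so the head weights are freshly initialised, and the label value is illustrative:

import torch
from transformers import AutoTokenizer, DistilBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
labels = torch.tensor([1])                                   # shape (batch_size,), class index in [0, num_labels - 1]
outputs = model(**inputs, labels=labels)                     # SequenceClassifierOutput
print(outputs.loss, outputs.logits.shape)                    # scalar loss, (1, 2)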
@@ -784,7 +807,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
DISTILBERT_START_DOCSTRING,
)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.distilbert = DistilBertModel(config)
@@ -824,16 +847,16 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor, ...]]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
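A usage sketch for this signature (not part of this commit), using the publicly available `distilbert-base-cased-distilled-squad` checkpoint so the predicted span is meaningful; the question and context are illustrative:

import torch
from transformers import AutoTokenizer, DistilBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")

question, context = "Who created DistilBERT?", "DistilBERT was created by Hugging Face."
inputs = tokenizer(question, context, return_tensors="pt")
outputs = model(**inputs)                                    # QuestionAnsweringModelOutput

start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax()
print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))

# Passing start_positions / end_positions (shape (batch_size,)) additionally returns a loss.
loss = model(**inputs, start_positions=start.unsqueeze(0), end_positions=end.unsqueeze(0)).loss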
@@ -901,7 +924,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
DISTILBERT_START_DOCSTRING,
)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.num_labels = config.num_labels
@@ -941,15 +964,15 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[TokenClassifierOutput, Tuple[torch.Tensor, ...]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
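A usage sketch for this signature (not part of this commit); `num_labels` and the random per-token labels are illustrative, so the classification head is freshly initialised:

import torch
from transformers import AutoTokenizer, DistilBertForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=5)

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
labels = torch.randint(0, 5, inputs.input_ids.shape)         # one label per token, in [0, num_labels - 1]
outputs = model(**inputs, labels=labels)                     # TokenClassifierOutput
print(outputs.loss, outputs.logits.shape)                    # scalar loss, (1, sequence_length, 5)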
@@ -996,7 +1019,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
DISTILBERT_START_DOCSTRING,
)
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.distilbert = DistilBertModel(config)
@@ -1033,15 +1056,15 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
@replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor, ...]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,