mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fixing model to add torchscript, embedding resizing, head pruning and masking + tests
This commit is contained in:
parent
62df4ba59a
commit
c9bce1811c
@ -449,7 +449,7 @@ class BertEncoder(nn.Module):
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # outputs, (hidden states), (attentions)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -30,7 +30,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pytorch_transformers.modeling_utils import PretrainedConfig, PreTrainedModel, add_start_docstrings
|
||||
from pytorch_transformers.modeling_utils import PretrainedConfig, PreTrainedModel, add_start_docstrings, prune_linear_layer
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -92,6 +92,17 @@ class DilBertConfig(PretrainedConfig):
|
||||
else:
|
||||
raise ValueError("First argument must be either a vocabulary size (int)"
|
||||
" or the path to a pretrained model config file (str)")
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.hidden_dim
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.n_heads
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.n_layers
|
||||
|
||||
|
||||
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
|
||||
@ -163,11 +174,30 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
||||
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
||||
|
||||
def prune_heads(self, heads):
|
||||
attention_head_size = self.dim // self.n_heads
|
||||
if len(heads) == 0:
|
||||
return
|
||||
mask = torch.ones(self.n_heads, attention_head_size)
|
||||
for head in heads:
|
||||
mask[head] = 0
|
||||
mask = mask.view(-1).contiguous().eq(1)
|
||||
index = torch.arange(len(mask))[mask].long()
|
||||
# Prune linear layers
|
||||
self.q_lin = prune_linear_layer(self.q_lin, index)
|
||||
self.k_lin = prune_linear_layer(self.k_lin, index)
|
||||
self.v_lin = prune_linear_layer(self.v_lin, index)
|
||||
self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
|
||||
# Update hyper params
|
||||
self.n_heads = self.n_heads - len(heads)
|
||||
self.dim = attention_head_size * self.n_heads
|
||||
|
||||
def forward(self,
|
||||
query: torch.tensor,
|
||||
key: torch.tensor,
|
||||
value: torch.tensor,
|
||||
mask: torch.tensor):
|
||||
mask: torch.tensor,
|
||||
head_mask: torch.tensor = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@ -185,10 +215,10 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
"""
|
||||
bs, q_length, dim = query.size()
|
||||
k_length = key.size(1)
|
||||
assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||
assert key.size() == value.size()
|
||||
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||
# assert key.size() == value.size()
|
||||
|
||||
dim_per_head = dim // self.n_heads
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
|
||||
assert 2 <= mask.dim() <= 3
|
||||
causal = (mask.dim() == 3)
|
||||
@ -200,7 +230,7 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
|
||||
def unshape(x):
|
||||
""" group heads """
|
||||
return x.transpose(1, 2).contiguous().view(bs, -1, dim)
|
||||
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
|
||||
|
||||
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
|
||||
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
|
||||
@ -213,6 +243,11 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
|
||||
weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length)
|
||||
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
weights = weights * head_mask
|
||||
|
||||
context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
|
||||
context = unshape(context) # (bs, q_length, dim)
|
||||
context = self.out_lin(context) # (bs, q_length, dim)
|
||||
@ -229,7 +264,7 @@ class FFN(nn.Module):
|
||||
self.dropout = nn.Dropout(p=config.dropout)
|
||||
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
|
||||
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
|
||||
assert config.activation in ['relu', 'gelu'], ValueError(f"activation ({config.activation}) must be in ['relu', 'gelu']")
|
||||
assert config.activation in ['relu', 'gelu'], "activation ({}) must be in ['relu', 'gelu']".format(config.activation)
|
||||
self.activation = gelu if config.activation == 'gelu' else nn.ReLU()
|
||||
|
||||
def forward(self,
|
||||
@ -262,7 +297,8 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
def forward(self,
|
||||
x: torch.tensor,
|
||||
attn_mask: torch.tensor = None):
|
||||
attn_mask: torch.tensor = None,
|
||||
head_mask: torch.tensor = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@ -277,7 +313,7 @@ class TransformerBlock(nn.Module):
|
||||
The output of the transformer block contextualization.
|
||||
"""
|
||||
# Self-Attention
|
||||
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
|
||||
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask)
|
||||
if self.output_attentions:
|
||||
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
|
||||
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
|
||||
@ -294,6 +330,7 @@ class TransformerBlock(nn.Module):
|
||||
output = (sa_weights,) + output
|
||||
return output
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self,
|
||||
config):
|
||||
@ -307,7 +344,8 @@ class Transformer(nn.Module):
|
||||
|
||||
def forward(self,
|
||||
x: torch.tensor,
|
||||
attn_mask: torch.tensor = None):
|
||||
attn_mask: torch.tensor = None,
|
||||
head_mask: torch.tensor = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@ -331,14 +369,24 @@ class Transformer(nn.Module):
|
||||
all_attentions = ()
|
||||
|
||||
hidden_state = x
|
||||
for _, layer_module in enumerate(self.layer):
|
||||
hidden_state = layer_module(x=hidden_state, attn_mask=attn_mask)
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
layer_outputs = layer_module(x=hidden_state,
|
||||
attn_mask=attn_mask,
|
||||
head_mask=head_mask[i])
|
||||
hidden_state = layer_outputs[-1]
|
||||
|
||||
if self.output_attentions:
|
||||
attentions, hidden_state = hidden_state
|
||||
assert len(layer_outputs) == 2
|
||||
attentions = layer_outputs[0]
|
||||
all_attentions = all_attentions + (attentions,)
|
||||
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
|
||||
assert type(hidden_state) == tuple
|
||||
hidden_state = hidden_state[0]
|
||||
else:
|
||||
assert len(layer_outputs) == 1
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
outputs = (hidden_state,)
|
||||
@ -346,7 +394,7 @@ class Transformer(nn.Module):
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
|
||||
@ -378,9 +426,21 @@ class DilBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
|
||||
DILBERT_START_DOCSTRING = r"""
|
||||
Smaller, faster, cheaper, lighter: DilBERT
|
||||
DilBERT is a small, fast, cheap and light Transformer model
|
||||
trained by distilling Bert base. It has 40% less parameters than
|
||||
`bert-base-uncased`, runs 60% faster while preserving over 95% of
|
||||
Bert's performances as measured on the GLUE language understanding benchmark.
|
||||
|
||||
For more information on DilBERT, you should check TODO(Link): Link to Medium
|
||||
Here are the differences between the interface of Bert and DilBert:
|
||||
|
||||
- DilBert doesn't have `token_type_ids`, you don't need to indicate which token belong to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`)
|
||||
- DilBert doesn't have options to select the input positions (`position_ids` input). This could be added if necessary though, just let's us know if you need this option.
|
||||
|
||||
For more information on DilBERT, please refer to our
|
||||
`detailed blog post`_
|
||||
|
||||
.. _`detailed blog post`:
|
||||
https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-dilbert-a-distilled-version-of-bert-8cf3380435b5
|
||||
|
||||
Parameters:
|
||||
config (:class:`~pytorch_transformers.DilBertConfig`): Model configuration class with all the parameters of the model.
|
||||
@ -399,31 +459,35 @@ DILBERT_INPUTS_DOCSTRING = r"""
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
|
||||
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
|
||||
class DilBertModel(DilBertPreTrainedModel):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
input_ids: torch.tensor(bs, seq_length)
|
||||
Sequences of token ids.
|
||||
attention_mask: torch.tensor(bs, seq_length)
|
||||
Attention mask on the sequences. Optional: If None, it's like there was no padding.
|
||||
|
||||
Outputs
|
||||
-------
|
||||
hidden_state: torch.tensor(bs, seq_length, dim)
|
||||
Sequence of hiddens states in the last (top) layer
|
||||
pooled_output: torch.tensor(bs, dim)
|
||||
Pooled output: for DilBert, the pooled output is simply the hidden state of the [CLS] token.
|
||||
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
|
||||
Tuple of length n_layers with the hidden states from each layer.
|
||||
Optional: only if output_hidden_states=True
|
||||
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
|
||||
Tuple of length n_layers with the attention weights from each layer
|
||||
Optional: only if output_attentions=True
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
|
||||
model = DilBertModel.from_pretrained('dilbert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(DilBertModel, self).__init__(config)
|
||||
@ -433,47 +497,83 @@ class DilBertModel(DilBertPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.transformer.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.tensor,
|
||||
attention_mask: torch.tensor = None):
|
||||
attention_mask: torch.tensor = None,
|
||||
head_mask: torch.tensor = None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
|
||||
tfmr_output = self.transformer(x=embedding_output,
|
||||
attn_mask=attention_mask)
|
||||
attn_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
hidden_state = tfmr_output[0]
|
||||
pooled_output = hidden_state[:, 0]
|
||||
output = (hidden_state, pooled_output) + tfmr_output[1:]
|
||||
output = (hidden_state, ) + tfmr_output[1:]
|
||||
|
||||
return output # last-layer hidden-state, (all hidden_states), (all attentions)
|
||||
|
||||
return output # hidden_state, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
@add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """,
|
||||
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
|
||||
class DilBertForMaskedLM(DilBertPreTrainedModel):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
input_ids: torch.tensor(bs, seq_length)
|
||||
Token ids.
|
||||
attention_mask: torch.tensor(bs, seq_length)
|
||||
Attention mask. Optional: If None, it's like there was no padding.
|
||||
masked_lm_labels: torch.tensor(bs, seq_length)
|
||||
The masked language modeling labels. Optional: If None, no loss is computed.
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for computing the masked language modeling loss.
|
||||
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Masked language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
|
||||
model = DilBertForMaskedLM.from_pretrained('dilbert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, masked_lm_labels=input_ids)
|
||||
loss, prediction_scores = outputs[:2]
|
||||
|
||||
Outputs
|
||||
-------
|
||||
mlm_loss: torch.tensor(1,)
|
||||
Masked Language Modeling loss to optimize.
|
||||
Optional: only if `masked_lm_labels` is not None
|
||||
prediction_logits: torch.tensor(bs, seq_length, voc_size)
|
||||
Token prediction logits
|
||||
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
|
||||
Tuple of length n_layers with the hidden states from each layer.
|
||||
Optional: only if `output_hidden_states`=True
|
||||
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
|
||||
Tuple of length n_layers with the attention weights from each layer
|
||||
Optional: only if `output_attentions`=True
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(DilBertForMaskedLM, self).__init__(config)
|
||||
@ -491,59 +591,68 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
|
||||
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
Tying the weights of the vocabulary projection to the base token embeddings.
|
||||
"""
|
||||
if self.config.tie_weights_:
|
||||
self.vocab_projector.weight = self.dilbert.embeddings.word_embeddings.weight
|
||||
self._tie_or_clone_weights(self.vocab_projector,
|
||||
self.dilbert.embeddings.word_embeddings)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.tensor,
|
||||
attention_mask: torch.tensor = None,
|
||||
masked_lm_labels: torch.tensor = None):
|
||||
masked_lm_labels: torch.tensor = None,
|
||||
head_mask: torch.tensor = None):
|
||||
dlbrt_output = self.dilbert(input_ids=input_ids,
|
||||
attention_mask=attention_mask)
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
|
||||
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
|
||||
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim)
|
||||
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
|
||||
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
|
||||
|
||||
outputs = (prediction_logits, ) + dlbrt_output[2:]
|
||||
outputs = (prediction_logits, ) + dlbrt_output[1:]
|
||||
if masked_lm_labels is not None:
|
||||
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
|
||||
masked_lm_labels.view(-1))
|
||||
outputs = (mlm_loss,) + outputs
|
||||
|
||||
return outputs # (mlm_loss), prediction_logits, (hidden_states), (attentions)
|
||||
return outputs # (mlm_loss), prediction_logits, (all hidden_states), (all attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""DilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
|
||||
class DilBertForSequenceClassification(DilBertPreTrainedModel):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
input_ids: torch.tensor(bs, seq_length)
|
||||
Token ids.
|
||||
attention_mask: torch.tensor(bs, seq_length)
|
||||
Attention mask. Optional: If None, it's like there was no padding.
|
||||
labels: torch.tensor(bs,)
|
||||
Classification Labels: Optional: If None, no loss will be computed.
|
||||
|
||||
Outputs
|
||||
-------
|
||||
loss: torch.tensor(1)
|
||||
Sequence classification loss.
|
||||
Optional: Is computed only if `labels` is not None.
|
||||
logits: torch.tensor(bs, seq_length)
|
||||
Classification (or regression if config.num_labels==1) scores
|
||||
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
|
||||
Tuple of length n_layers with the hidden states from each layer.
|
||||
Optional: only if `output_hidden_states`=True
|
||||
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
|
||||
Tuple of length n_layers with the attention weights from each layer
|
||||
Optional: only if `output_attentions`=True
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
|
||||
model = DilBertForSequenceClassification.from_pretrained('dilbert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(DilBertForSequenceClassification, self).__init__(config)
|
||||
@ -559,16 +668,19 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
|
||||
def forward(self,
|
||||
input_ids: torch.tensor,
|
||||
attention_mask: torch.tensor = None,
|
||||
labels: torch.tensor = None):
|
||||
labels: torch.tensor = None,
|
||||
head_mask: torch.tensor = None):
|
||||
dilbert_output = self.dilbert(input_ids=input_ids,
|
||||
attention_mask=attention_mask)
|
||||
pooled_output = dilbert_output[1] # (bs, dim)
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
hidden_state = dilbert_output[0] # (bs, seq_len, dim)
|
||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
||||
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
|
||||
pooled_output = self.dropout(pooled_output) # (bs, dim)
|
||||
logits = self.classifier(pooled_output) # (bs, dim)
|
||||
|
||||
outputs = (logits,) + dilbert_output[2:]
|
||||
outputs = (logits,) + dilbert_output[1:]
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
loss_fct = nn.MSELoss()
|
||||
@ -580,43 +692,46 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""DilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
|
||||
class DilBertForQuestionAnswering(DilBertPreTrainedModel):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
input_ids: torch.tensor(bs, seq_length)
|
||||
Token ids.
|
||||
attention_mask: torch.tensor(bs, seq_length)
|
||||
Attention mask. Optional: If None, it's like there was no padding.
|
||||
start_positions: torch,tensor(bs)
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
Optional: if None, no loss is computed.
|
||||
end_positions: torch,tensor(bs)
|
||||
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
Optional: if None, no loss is computed.
|
||||
|
||||
Outputs
|
||||
-------
|
||||
loss: torch.tensor(1)
|
||||
Question answering loss.
|
||||
Optional: Is computed only if `start_positions` and `end_positions` are not None.
|
||||
start_logits: torch.tensor(bs, seq_length)
|
||||
Span-start scores.
|
||||
end_logits: torch.tensor(bs, seq_length)
|
||||
Spand-end scores.
|
||||
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
|
||||
Tuple of length n_layers with the hidden states from each layer.
|
||||
Optional: only if `output_hidden_states`=True
|
||||
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
|
||||
Tuple of length n_layers with the attention weights from each layer
|
||||
Optional: only if `output_attentions`=True
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
|
||||
model = DilBertForQuestionAnswering.from_pretrained('dilbert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
start_positions = torch.tensor([1])
|
||||
end_positions = torch.tensor([3])
|
||||
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
|
||||
loss, start_scores, end_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(DilBertForQuestionAnswering, self).__init__(config)
|
||||
@ -632,9 +747,11 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
|
||||
input_ids: torch.tensor,
|
||||
attention_mask: torch.tensor = None,
|
||||
start_positions: torch.tensor = None,
|
||||
end_positions: torch.tensor = None):
|
||||
end_positions: torch.tensor = None,
|
||||
head_mask: torch.tensor = None):
|
||||
dilbert_output = self.dilbert(input_ids=input_ids,
|
||||
attention_mask=attention_mask)
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
hidden_states = dilbert_output[0] # (bs, max_query_len, dim)
|
||||
|
||||
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
|
||||
@ -643,7 +760,7 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
|
||||
start_logits = start_logits.squeeze(-1) # (bs, max_query_len)
|
||||
end_logits = end_logits.squeeze(-1) # (bs, max_query_len)
|
||||
|
||||
outputs = (start_logits, end_logits,) + dilbert_output[2:]
|
||||
outputs = (start_logits, end_logits,) + dilbert_output[1:]
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
|
@ -21,7 +21,7 @@ import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM,
|
||||
DilBertForQuestionAnswering, DilBertForSequenceClassification)
|
||||
DilBertForQuestionAnswering, DilBertForSequenceClassification)
|
||||
from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
|
||||
@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
|
||||
DilBertForSequenceClassification)
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
|
||||
class DilBertModelTester(object):
|
||||
|
||||
@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DilBertModel(config=config)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, input_mask)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
(sequence_output,) = model(input_ids, input_mask)
|
||||
(sequence_output,) = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DilBertForMaskedLM(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, input_mask, token_labels)
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
|
Loading…
Reference in New Issue
Block a user