Fix model: add TorchScript support, embedding resizing, head pruning and head masking + tests

thomwolf 2019-08-28 13:22:45 +02:00
parent 62df4ba59a
commit c9bce1811c
3 changed files with 253 additions and 138 deletions

pytorch_transformers/modeling_bert.py

@@ -449,7 +449,7 @@ class BertEncoder(nn.Module):
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):

pytorch_transformers/modeling_dilbert.py

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@ import numpy as np
import torch
import torch.nn as nn
from pytorch_transformers.modeling_utils import PretrainedConfig, PreTrainedModel, add_start_docstrings
from pytorch_transformers.modeling_utils import PretrainedConfig, PreTrainedModel, add_start_docstrings, prune_linear_layer
import logging
logger = logging.getLogger(__name__)
@@ -92,6 +92,17 @@ class DilBertConfig(PretrainedConfig):
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
@property
def hidden_size(self):
return self.hidden_dim
@property
def num_attention_heads(self):
return self.n_heads
@property
def num_hidden_layers(self):
return self.n_layers
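These read-only properties alias DilBERT's own attribute names to the common names (`hidden_size`, `num_attention_heads`, `num_hidden_layers`) that the shared base-class utilities and common tests expect. A minimal sketch of the pattern (class and sizes are illustrative, not the real config)::

    class TinyConfig(object):
        def __init__(self, dim=768, n_heads=12, n_layers=6):
            self.hidden_dim = dim
            self.n_heads = n_heads
            self.n_layers = n_layers

        @property
        def hidden_size(self):
            return self.hidden_dim  # same value, exposed under the common name

    cfg = TinyConfig()
    assert cfg.hidden_size == cfg.hidden_dim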
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
@@ -163,11 +174,30 @@ class MultiHeadSelfAttention(nn.Module):
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
def prune_heads(self, heads):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
for head in heads:
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.q_lin = prune_linear_layer(self.q_lin, index)
self.k_lin = prune_linear_layer(self.k_lin, index)
self.v_lin = prune_linear_layer(self.v_lin, index)
self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.dim = attention_head_size * self.n_heads
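The index computation in `prune_heads` can be checked in isolation. A self-contained sketch with illustrative sizes: pruning one of four heads yields an index that selects exactly the rows of the q/k/v projections belonging to the three kept heads::

    import torch

    n_heads, head_size = 4, 16
    heads = [1]                                 # prune head 1

    mask = torch.ones(n_heads, head_size)
    for head in heads:
        mask[head] = 0
    mask = mask.view(-1).contiguous().eq(1)     # flat boolean keep-mask
    index = torch.arange(len(mask))[mask].long()

    # 48 of the 64 output units survive; prune_linear_layer slices the Linear
    # weights with this index (dim 0 for q/k/v, dim 1 for the output projection).
    assert index.numel() == (n_heads - len(heads)) * head_size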
def forward(self,
query: torch.tensor,
key: torch.tensor,
value: torch.tensor,
mask: torch.tensor):
mask: torch.tensor,
head_mask: torch.tensor = None):
"""
Parameters
----------
@@ -185,10 +215,10 @@ class MultiHeadSelfAttention(nn.Module):
"""
bs, q_length, dim = query.size()
k_length = key.size(1)
assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
assert key.size() == value.size()
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
# assert key.size() == value.size()
dim_per_head = dim // self.n_heads
dim_per_head = self.dim // self.n_heads
assert 2 <= mask.dim() <= 3
causal = (mask.dim() == 3)
@@ -200,7 +230,7 @@ class MultiHeadSelfAttention(nn.Module):
def unshape(x):
""" group heads """
return x.transpose(1, 2).contiguous().view(bs, -1, dim)
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
@@ -213,6 +243,11 @@ class MultiHeadSelfAttention(nn.Module):
weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim)
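The `weights * head_mask` product above silences a head by zeroing its attention weights, so that head contributes nothing to the context vectors. A small sketch with illustrative shapes::

    import torch

    bs, n_heads, q_length, k_length = 1, 4, 5, 5
    weights = torch.softmax(torch.randn(bs, n_heads, q_length, k_length), dim=-1)

    head_mask = torch.ones(1, n_heads, 1, 1)    # broadcasts over batch and positions
    head_mask[:, 2] = 0.0                       # mask head 2

    masked = weights * head_mask
    assert masked[:, 2].abs().sum().item() == 0.0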
@@ -229,7 +264,7 @@ class FFN(nn.Module):
self.dropout = nn.Dropout(p=config.dropout)
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
assert config.activation in ['relu', 'gelu'], ValueError(f"activation ({config.activation}) must be in ['relu', 'gelu']")
assert config.activation in ['relu', 'gelu'], "activation ({}) must be in ['relu', 'gelu']".format(config.activation)
self.activation = gelu if config.activation == 'gelu' else nn.ReLU()
def forward(self,
@@ -262,7 +297,8 @@ class TransformerBlock(nn.Module):
def forward(self,
x: torch.tensor,
attn_mask: torch.tensor = None):
attn_mask: torch.tensor = None,
head_mask: torch.tensor = None):
"""
Parameters
----------
@@ -277,7 +313,7 @@ class TransformerBlock(nn.Module):
The output of the transformer block contextualization.
"""
# Self-Attention
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask)
if self.output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
@@ -294,6 +330,7 @@ class TransformerBlock(nn.Module):
output = (sa_weights,) + output
return output
class Transformer(nn.Module):
def __init__(self,
config):
@@ -307,7 +344,8 @@ class Transformer(nn.Module):
def forward(self,
x: torch.tensor,
attn_mask: torch.tensor = None):
attn_mask: torch.tensor = None,
head_mask: torch.tensor = None):
"""
Parameters
----------
@@ -331,14 +369,24 @@ class Transformer(nn.Module):
all_attentions = ()
hidden_state = x
for _, layer_module in enumerate(self.layer):
hidden_state = layer_module(x=hidden_state, attn_mask=attn_mask)
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module(x=hidden_state,
attn_mask=attn_mask,
head_mask=head_mask[i])
hidden_state = layer_outputs[-1]
if self.output_attentions:
attentions, hidden_state = hidden_state
assert len(layer_outputs) == 2
attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,)
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
assert type(hidden_state) == tuple
hidden_state = hidden_state[0]
else:
assert len(layer_outputs) == 1
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,)
@@ -346,7 +394,7 @@ class Transformer(nn.Module):
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs
return outputs # last-layer hidden state, (all hidden states), (all attentions)
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
@@ -378,9 +426,21 @@ class DilBertPreTrainedModel(PreTrainedModel):
DILBERT_START_DOCSTRING = r"""
Smaller, faster, cheaper, lighter: DilBERT
DilBERT is a small, fast, cheap and light Transformer model
trained by distilling Bert base. It has 40% fewer parameters than
`bert-base-uncased`, runs 60% faster and preserves over 95% of
Bert's performance as measured on the GLUE language understanding benchmark.
For more information on DilBERT, you should check TODO(Link): Link to Medium
Here are the differences between the interface of Bert and DilBert:
- DilBert doesn't have `token_type_ids`: you don't need to indicate which tokens belong to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`)
- DilBert doesn't have options to select the input positions (`position_ids` input). This could be added if necessary though; just let us know if you need this option.
For more information on DilBERT, please refer to our
`detailed blog post`_
.. _`detailed blog post`:
https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-dilbert-a-distilled-version-of-bert-8cf3380435b5
Parameters:
config (:class:`~pytorch_transformers.DilBertConfig`): Model configuration class with all the parameters of the model.
@@ -399,31 +459,35 @@ DILBERT_INPUTS_DOCSTRING = r"""
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertModel(DilBertPreTrainedModel):
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
Sequences of token ids.
attention_mask: torch.tensor(bs, seq_length)
Attention mask on the sequences. Optional: If None, it's like there was no padding.
Outputs
-------
hidden_state: torch.tensor(bs, seq_length, dim)
Sequence of hiddens states in the last (top) layer
pooled_output: torch.tensor(bs, dim)
Pooled output: for DilBert, the pooled output is simply the hidden state of the [CLS] token.
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
Tuple of length n_layers with the hidden states from each layer.
Optional: only if output_hidden_states=True
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
model = DilBertModel.from_pretrained('dilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(DilBertModel, self).__init__(config)
@@ -433,47 +497,83 @@ class DilBertModel(DilBertPreTrainedModel):
self.apply(self.init_weights)
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
return self.embeddings.word_embeddings
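`_get_resized_embeddings` comes from the `PreTrainedModel` base class. A hedged sketch of what it does, not its exact code: allocate a larger embedding matrix and copy the learned rows into it, so only the new rows start untrained::

    import torch.nn as nn

    old = nn.Embedding(30522, 768)              # illustrative vocab size / dim
    new_num_tokens = 30522 + 8                  # e.g. after adding special tokens
    new = nn.Embedding(new_num_tokens, 768)
    new.weight.data[:old.num_embeddings, :] = old.weight.data  # keep learned rows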
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.transformer.layer[layer].attention.prune_heads(heads)
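`_prune_heads` is reached through the `prune_heads` entry point on the base class, so pruning a trained model is one call (the head choices here are illustrative)::

    model = DilBertModel.from_pretrained('dilbert-base-uncased')
    # Permanently remove heads 0 and 2 of layer 0, and head 1 of layer 3:
    model.prune_heads({0: [0, 2], 3: [1]})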
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None):
attention_mask: torch.tensor = None,
head_mask: torch.tensor = None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
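The unsqueeze/expand sequence above just aligns a 1-D or 2-D user-supplied mask with the per-layer shape the attention modules broadcast against. Checking the 1-D case with illustrative sizes::

    import torch

    n_layers, n_heads = 6, 12
    head_mask = torch.ones(n_heads)             # one mask shared by every layer
    head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    head_mask = head_mask.expand(n_layers, -1, -1, -1, -1)
    assert head_mask.shape == (n_layers, 1, n_heads, 1, 1)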
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=embedding_output,
attn_mask=attention_mask)
attn_mask=attention_mask,
head_mask=head_mask)
hidden_state = tfmr_output[0]
pooled_output = hidden_state[:, 0]
output = (hidden_state, pooled_output) + tfmr_output[1:]
output = (hidden_state, ) + tfmr_output[1:]
return output # last-layer hidden-state, (all hidden_states), (all attentions)
return output # hidden_state, pooled_output, (hidden_states), (attentions)
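With the pooled output dropped from `DilBertModel`'s outputs, callers that need it now take the hidden state of the first ([CLS]) token themselves, exactly as the sequence-classification head does further down::

    outputs = model(input_ids)                  # (last_hidden_state, ...)
    last_hidden = outputs[0]                    # (bs, seq_length, dim)
    pooled = last_hidden[:, 0]                  # (bs, dim)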
@add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForMaskedLM(DilBertPreTrainedModel):
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
Token ids.
attention_mask: torch.tensor(bs, seq_length)
Attention mask. Optional: If None, it's like there was no padding.
masked_lm_labels: torch.tensor(bs, seq_length)
The masked language modeling labels. Optional: If None, no loss is computed.
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
model = DilBertForMaskedLM.from_pretrained('dilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
Outputs
-------
mlm_loss: torch.tensor(1,)
Masked Language Modeling loss to optimize.
Optional: only if `masked_lm_labels` is not None
prediction_logits: torch.tensor(bs, seq_length, voc_size)
Token prediction logits
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
Tuple of length n_layers with the hidden states from each layer.
Optional: only if `output_hidden_states`=True
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if `output_attentions`=True
"""
def __init__(self, config):
super(DilBertForMaskedLM, self).__init__(config)
@@ -491,59 +591,68 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
Tying the weights of the vocabulary projection to the base token embeddings.
"""
if self.config.tie_weights_:
self.vocab_projector.weight = self.dilbert.embeddings.word_embeddings.weight
self._tie_or_clone_weights(self.vocab_projector,
self.dilbert.embeddings.word_embeddings)
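`_tie_or_clone_weights` is the base-class helper behind the new docstring: share the projection and embedding parameter when possible, clone it when exporting to TorchScript, since tracing can't handle shared parameters. A hedged sketch of the idea (function name is illustrative)::

    import torch.nn as nn

    def tie_or_clone(output_layer, input_embeddings, torchscript=False):
        if torchscript:
            # cloned copy: numerically identical, but no longer shared
            output_layer.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_layer.weight = input_embeddings.weight  # genuinely shared Parameter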
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
masked_lm_labels: torch.tensor = None):
masked_lm_labels: torch.tensor = None,
head_mask: torch.tensor = None):
dlbrt_output = self.dilbert(input_ids=input_ids,
attention_mask=attention_mask)
attention_mask=attention_mask,
head_mask=head_mask)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
outputs = (prediction_logits, ) + dlbrt_output[2:]
outputs = (prediction_logits, ) + dlbrt_output[1:]
if masked_lm_labels is not None:
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
masked_lm_labels.view(-1))
outputs = (mlm_loss,) + outputs
return outputs # (mlm_loss), prediction_logits, (hidden_states), (attentions)
return outputs # (mlm_loss), prediction_logits, (all hidden_states), (all attentions)
@add_start_docstrings("""DilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForSequenceClassification(DilBertPreTrainedModel):
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
Token ids.
attention_mask: torch.tensor(bs, seq_length)
Attention mask. Optional: If None, it's like there was no padding.
labels: torch.tensor(bs,)
Classification Labels: Optional: If None, no loss will be computed.
Outputs
-------
loss: torch.tensor(1)
Sequence classification loss.
Optional: Is computed only if `labels` is not None.
logits: torch.tensor(bs, seq_length)
Classification (or regression if config.num_labels==1) scores
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
Tuple of length n_layers with the hidden states from each layer.
Optional: only if `output_hidden_states`=True
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if `output_attentions`=True
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
model = DilBertForSequenceClassification.from_pretrained('dilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(DilBertForSequenceClassification, self).__init__(config)
@@ -559,16 +668,19 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
labels: torch.tensor = None):
labels: torch.tensor = None,
head_mask: torch.tensor = None):
dilbert_output = self.dilbert(input_ids=input_ids,
attention_mask=attention_mask)
pooled_output = dilbert_output[1] # (bs, dim)
attention_mask=attention_mask,
head_mask=head_mask)
hidden_state = dilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
outputs = (logits,) + dilbert_output[2:]
outputs = (logits,) + dilbert_output[1:]
if labels is not None:
if self.num_labels == 1:
loss_fct = nn.MSELoss()
@@ -580,43 +692,46 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""DilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForQuestionAnswering(DilBertPreTrainedModel):
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
Token ids.
attention_mask: torch.tensor(bs, seq_length)
Attention mask. Optional: If None, it's like there was no padding.
start_positions: torch,tensor(bs)
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Optional: if None, no loss is computed.
end_positions: torch,tensor(bs)
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Optional: if None, no loss is computed.
Outputs
-------
loss: torch.tensor(1)
Question answering loss.
Optional: Is computed only if `start_positions` and `end_positions` are not None.
start_logits: torch.tensor(bs, seq_length)
Span-start scores.
end_logits: torch.tensor(bs, seq_length)
Spand-end scores.
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
Tuple of length n_layers with the hidden states from each layer.
Optional: only if `output_hidden_states`=True
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if `output_attentions`=True
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased')
model = DilBertForQuestionAnswering.from_pretrained('dilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:3]
"""
def __init__(self, config):
super(DilBertForQuestionAnswering, self).__init__(config)
@@ -632,9 +747,11 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
start_positions: torch.tensor = None,
end_positions: torch.tensor = None):
end_positions: torch.tensor = None,
head_mask: torch.tensor = None):
dilbert_output = self.dilbert(input_ids=input_ids,
attention_mask=attention_mask)
attention_mask=attention_mask,
head_mask=head_mask)
hidden_states = dilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
@@ -643,7 +760,7 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
start_logits = start_logits.squeeze(-1) # (bs, max_query_len)
end_logits = end_logits.squeeze(-1) # (bs, max_query_len)
outputs = (start_logits, end_logits,) + dilbert_output[2:]
outputs = (start_logits, end_logits,) + dilbert_output[1:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:

pytorch_transformers/tests/modeling_dilbert_test.py

@@ -21,7 +21,7 @@ import shutil
import pytest
from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM,
DilBertForQuestionAnswering, DilBertForSequenceClassification)
DilBertForQuestionAnswering, DilBertForSequenceClassification)
from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
@@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
DilBertForSequenceClassification)
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
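Enabling `test_torchscript` means the common tests now trace the model. Roughly what that exercises, as a sketch rather than the shared test code (the `torchscript` config flag is assumed to select the weight-cloning path in `tie_weights`)::

    import torch

    config = DilBertConfig(torchscript=True)    # assumed: config accepts this flag
    model = DilBertForMaskedLM(config)
    model.eval()
    input_ids = torch.randint(0, 30522, (1, 7)) # illustrative batch
    traced = torch.jit.trace(model, input_ids)
    torch.jit.save(traced, 'traced_dilbert.pt')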
class DilBertModelTester(object):
@@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, input_mask)
sequence_output, pooled_output = model(input_ids)
(sequence_output,) = model(input_ids, input_mask)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, input_mask, token_labels)
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,