updating head masking, readme and docstrings

This commit is contained in:
thomwolf 2019-06-17 15:51:28 +02:00
parent 965f172de6
commit 33d3db5c43
7 changed files with 220 additions and 43 deletions


@ -474,7 +474,7 @@ Here is a detailed documentation of the classes in the package and how to use th
To load one of Google AI's or OpenAI's pre-trained models, or a PyTorch saved model (an instance of `BertForPreTraining` saved with `torch.save()`), the PyTorch model classes and the tokenizer can be instantiated as
```python
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None, from_tf=False, state_dict=None, *input, **kwargs)
```
where
@ -505,7 +505,12 @@ where
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BertForPreTraining`, `OpenAIGPTModel`, `TransfoXLModel`, `GPT2LMHeadModel` (saved with the usual `torch.save()`)
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information).
- `from_tf`: whether to load the model weights from a locally saved TensorFlow checkpoint
- `state_dict`: an optional state dictionary (collections.OrderedDict object) to use instead of Google's pre-trained models
- `*inputs`, `**kwargs`: additional inputs for the specific Bert class (e.g. `num_labels` for `BertForSequenceClassification`)
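For example (a minimal sketch; the shortcut name, file name and `num_labels` value are only illustrative):
```python
import torch
from pytorch_pretrained_bert import BertForSequenceClassification

# Download (or reuse the cached) pre-trained weights, with a rank-specific cache
# directory to avoid concurrent writes during distributed training
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    cache_dir='./pretrained_model_0',
    num_labels=2)  # extra kwarg forwarded to BertForSequenceClassification.__init__

# Or start from your own PyTorch dump instead of the downloaded weights
state_dict = torch.load('my_finetuned_bert.bin', map_location='cpu')
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', state_dict=state_dict, num_labels=2)
```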
`Uncased` means that the text has been lowercased before WordPiece tokenization, e.g., `John Smith` becomes `john smith`. The Uncased model also strips out any accent markers. `Cased` means that the true case and accent markers are preserved. Typically, the Uncased model is better unless you know that case information is important for your task (e.g., Named Entity Recognition or Part-of-Speech tagging). For information about the Multilingual and Chinese model, see the [Multilingual README](https://github.com/google-research/bert/blob/master/multilingual.md) or the original TensorFlow repository.
@ -631,6 +636,13 @@ These configuration classes contains a few utilities to load and save configurat
`BertModel` is the basic BERT Transformer model with a layer of summed token, position and sequence embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large).
Instantiation:
The model can be instantiated with the following arguments:
- `config`: a `BertConfig` class instance with the configuration to build a new model.
- `output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
- `keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient. This can be used to compute head importance metrics. Default: False
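As a rough sketch of what `keep_multihead_output` enables (the `get_multihead_outputs()` accessor and the importance formula below are illustrative assumptions, not a documented recipe; check the model class for the exact API):
```python
import torch
from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased', keep_multihead_output=True)
model.eval()  # gradients still flow as long as the forward pass is not wrapped in torch.no_grad()

input_ids = torch.tensor([[31, 51, 99]])          # toy token ids
encoded_layers, pooled_output = model(input_ids)
pooled_output.sum().backward()                    # any scalar objective works for this sketch

# Assumed accessor returning one saved tensor per layer,
# each of shape [batch_size, num_heads, seq_length, head_size]
multihead_outputs = model.get_multihead_outputs()
importance = torch.stack([(o * o.grad).abs().sum(-1).sum(-1).sum(0)   # per-head score
                          for o in multihead_outputs])                # -> [num_layers, num_heads]
```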
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
@ -639,6 +651,7 @@ We detail them here. This model takes as *inputs*:
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
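A minimal sketch of these inputs for `bert-base-uncased` (12 layers, 12 attention heads), using the convention above where 1.0 masks a head:
```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

tokens = tokenizer.tokenize("[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]")
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

head_mask = torch.zeros(12, 12)  # [num_layers, num_heads]: nothing masked by default
head_mask[-1, 1:] = 1.0          # mask every head except the first one on the last layer

with torch.no_grad():
    encoded_layers, pooled_output = model(input_ids, head_mask=head_mask)
```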
This model *outputs* a tuple composed of:
@ -756,6 +769,13 @@ where total_tokens_embeddings can be obtained as config.total_tokens_embeddings
`total_tokens_embeddings = config.vocab_size + config.n_special`
You should use the associated indices to index the embeddings.
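For example (a sketch based on the ROCStories fine-tuning example; the special-token names are arbitrary):
```python
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel

special_tokens = ['_start_', '_delimiter_', '_classify_']
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt', special_tokens=special_tokens)
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt', num_special_tokens=len(special_tokens))

# Special tokens are appended after the regular vocabulary, so their embedding
# indices run from config.vocab_size up to total_tokens_embeddings - 1
start_id, delim_id, clf_id = tokenizer.convert_tokens_to_ids(special_tokens)
```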
Instantiation:
The model can be instantiated with the following arguments:
- `config`: an `OpenAIGPTConfig` class instance with the configuration to build a new model.
- `output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
- `keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient. This can be used to compute head importance metrics. Default: False
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
@ -766,9 +786,10 @@ We detail them here. This model takes as *inputs*:
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
This model *outputs*:
- `hidden_states`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
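A short sketch of reading these outputs (the return order with `output_attentions=True`, attentions first, is assumed from the docstrings):
```python
import torch
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTModel.from_pretrained('openai-gpt', output_attentions=True)
model.eval()

tokens = tokenizer.tokenize("here is some text to encode")
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
with torch.no_grad():
    all_attentions, hidden_states = model(input_ids)

print(len(hidden_states))        # number of layers + 1 (the embedding output comes first)
print(hidden_states[-1].shape)   # [batch_size, sequence_length, hidden_size]
print(all_attentions[0].shape)   # [batch_size, num_heads, sequence_length, sequence_length]
```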
#### 10. `OpenAIGPTLMHeadModel`
@ -848,6 +869,13 @@ all_hidden_states = lower_hidden_states + [hidden_states]
`GPT2Model` is the OpenAI GPT-2 Transformer model with a layer of summed token and position embeddings followed by a series of 12 identical self-attention blocks.
Instantiation:
The model can be instantiated with the following arguments:
- `config`: a `GPT2Config` class instance with the configuration to build a new model.
- `output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
- `keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient. This can be used to compute head importance metrics. Default: False
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
@ -859,9 +887,10 @@ We detail them here. This model takes as *inputs*:
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
- `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states (key and values in the attention blocks) to speed up sequential decoding (this is the `presents` output of the model, cf. below).
- `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
This model *outputs*:
- `hidden_states`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
- `presents`: a list of pre-computed hidden-states (key and values in each attention block) as torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
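A sketch of reusing `presents` as `past` for step-by-step decoding, here with `GPT2LMHeadModel` (next section) so we get logits to pick the next token; the `(lm_logits, presents)` return order is assumed from the docstrings:
```python
import torch
from pytorch_pretrained_bert import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

generated = torch.tensor([tokenizer.encode("The Manhattan bridge")])
past = None
with torch.no_grad():
    for _ in range(10):
        # After the first step, only the last generated token needs to be fed:
        # the keys/values of the previous tokens are already cached in `past`
        input_ids = generated if past is None else generated[:, -1:]
        lm_logits, past = model(input_ids, past=past)
        next_token = torch.argmax(lm_logits[:, -1, :], dim=-1).unsqueeze(-1)
        generated = torch.cat([generated, next_token], dim=1)

print(tokenizer.decode(generated[0].tolist()))
```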
#### 15. `GPT2LMHeadModel`


@ -477,8 +477,8 @@ class BertEncoder(nn.Module):
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
all_encoder_layers = []
all_attentions = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask, head_mask)
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])
if self.output_attentions:
attentions, hidden_states = hidden_states
all_attentions.append(attentions)
@ -618,6 +618,9 @@ class BertPreTrainedModel(nn.Module):
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
. `bert-base-german-cased`
. `bert-large-uncased-whole-word-masking`
. `bert-large-cased-whole-word-masking`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
@ -744,7 +747,10 @@ class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@ -758,6 +764,9 @@ class BertModel(BertPreTrainedModel):
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by the `output_all_encoded_layers` argument:
@ -828,15 +837,19 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1 in head_mask indicate we need to mask the head
# 1.0 in head_mask indicates we mask the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape num_hidden_layers x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each instance in batch
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
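# Note: the inversion below turns the user convention (1.0 = mask this head) into a
# multiplicative factor: 0.0 zeroes the head's attention probabilities, 1.0 keeps them.
# head_mask then has shape [num_hidden_layers, 1, num_heads, 1, 1] and is broadcast
# over batch size and sequence length when applied in each layer.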
head_mask = (1.0 - head_mask)
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
@ -861,7 +874,10 @@ class BertForPreTraining(BertPreTrainedModel):
- the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@ -880,6 +896,8 @@ class BertForPreTraining(BertPreTrainedModel):
`next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `masked_lm_labels` and `next_sentence_label` are not `None`:
@ -937,7 +955,10 @@ class BertForMaskedLM(BertPreTrainedModel):
This module comprises the BERT model followed by the masked language modeling head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@ -953,6 +974,12 @@ class BertForMaskedLM(BertPreTrainedModel):
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
`head_mask`: an optional torch.LongTensor of shape [num_heads] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `masked_lm_labels` is not `None`:
@ -1006,7 +1033,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
This module comprises the BERT model followed by the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@ -1022,6 +1052,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `next_sentence_label` is not `None`:
@ -1077,7 +1109,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
@ -1093,6 +1128,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `labels` is not `None`:
@ -1150,7 +1187,10 @@ class BertForMultipleChoice(BertPreTrainedModel):
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
@ -1166,6 +1206,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `labels` is not `None`:
@ -1226,7 +1268,10 @@ class BertForTokenClassification(BertPreTrainedModel):
the full hidden state of the last layer.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
@ -1242,6 +1287,8 @@ class BertForTokenClassification(BertPreTrainedModel):
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [0, ..., num_labels].
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `labels` is not `None`:
@ -1306,7 +1353,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
the sequence output that computes start_logits and end_logits
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@ -1325,6 +1375,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `start_positions` and `end_positions` are not `None`:


@ -610,7 +610,10 @@ class GPT2Model(GPT2PreTrainedModel):
You should use the associated indices to index the embeddings.
Params:
config: a GPT2Config class instance with the configuration to build a new model
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
@ -625,10 +628,12 @@ class GPT2Model(GPT2PreTrainedModel):
`past`: an optional list of torch.LongTensor that contains pre-computed hidden-states
(key and values in the attention blocks) to speed up sequential decoding
(this is the presents output of the model, cf. below).
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs a tuple consisting of:
`hidden_states`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
`hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings)
as torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
(or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
`presents`: a list of pre-computed hidden-states (key and values in each attention block) as
torch.FloatTensors. They can be reused to speed up sequential decoding.
@ -698,13 +703,17 @@ class GPT2Model(GPT2PreTrainedModel):
# Prepare head mask if needed
# 1.0 in head_mask indicates we mask the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each instance in batch
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
head_mask = (1.0 - head_mask)
else:
head_mask = [None] * self.config.n_layer
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
@ -725,9 +734,9 @@ class GPT2Model(GPT2PreTrainedModel):
presents = []
all_attentions = []
all_hidden_states = []
for block, layer_past in zip(self.h, past):
for i, (block, layer_past) in enumerate(zip(self.h, past)):
all_hidden_states.append(hidden_states.view(*output_shape))
outputs = block(hidden_states, layer_past, head_mask)
outputs = block(hidden_states, layer_past, head_mask[i])
if self.output_attentions:
attentions, hidden_states, present = outputs
all_attentions.append(attentions)
@ -746,7 +755,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
"""OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").
Params:
config: a GPT2Config class instance with the configuration to build a new model
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
@ -764,6 +776,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
`past`: an optional list of torch.LongTensor that contains pre-computed hidden-states
(key and values in the attention blocks) to speed up sequential decoding
(this is the presents output of the model, cf. below).
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `lm_labels` is not `None`:
@ -828,7 +842,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
"""OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").
Params:
config: a GPT2Config class instance with the configuration to build a new model
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
@ -850,6 +867,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
`past`: an optional list of torch.LongTensor that contains pre-computed hidden-states
(key and values in the attention blocks) to speed up sequential decoding
(this is the presents output of the model, cf. below).
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `lm_labels` and `multiple_choice_labels` are not `None`:


@ -613,7 +613,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
You should use the associated indices to index the embeddings.
Params:
config: a OpenAIGPTConfig class instance with the configuration to build a new model
`config`: an OpenAIGPTConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
@ -625,10 +628,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
`hidden_states`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
`hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings)
as torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
(or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
Example usage:
@ -694,13 +699,17 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Prepare head mask if needed
# 1.0 in head_mask indicates we mask the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each instance in batch
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
head_mask = (1.0 - head_mask)
else:
head_mask = [None] * self.config.n_layer
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
@ -720,8 +729,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
all_attentions = []
all_hidden_states = [hidden_states.view(*output_shape)]
for block in self.h:
outputs = block(hidden_states, head_mask)
for i, block in enumerate(self.h):
outputs = block(hidden_states, head_mask[i])
if self.output_attentions:
attentions, hidden_states = outputs
all_attentions.append(attentions)
@ -755,7 +764,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
You should use the associated indices to index the embeddings.
Params:
config: a OpenAIGPTConfig class instance with the configuration to build a new model
`config`: an OpenAIGPTConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
@ -770,6 +782,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `lm_labels` is not `None`:
@ -847,7 +861,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
You should use the associated indices to index the embeddings.
Params:
config: a OpenAIGPTConfig class instance with the configuration to build a new model
`config`: an OpenAIGPTConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output the attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves the output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
@ -866,6 +883,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
is only computed for the labels set in [0, ..., total_tokens_embeddings]
`multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `lm_labels` and `multiple_choice_labels` are not `None`:


@ -215,9 +215,9 @@ class GPT2ModelTest(unittest.TestCase):
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.ones(self.n_head).to(input_ids.device)
head_mask[0] = 0.0
head_mask[-1] = 0.0 # Mask all but the first and last heads
head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
if isinstance(model, GPT2DoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
@ -246,6 +246,25 @@ class GPT2ModelTest(unittest.TestCase):
len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[1].nonzero()),
multihead_outputs[1].numel())
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
def create_and_check_gpt2_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):


@ -188,9 +188,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.ones(self.n_head).to(input_ids.device)
head_mask[0] = 0.0
head_mask[-1] = 0.0 # Mask all but the first and last heads
head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
if isinstance(model, OpenAIGPTDoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
@ -219,6 +219,26 @@ class OpenAIGPTModelTest(unittest.TestCase):
len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[1].nonzero()),
multihead_outputs[1].numel())
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):


@ -305,9 +305,9 @@ class BertModelTest(unittest.TestCase):
else:
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.ones(self.num_attention_heads).to(input_ids.device)
head_mask[0] = 0.0
head_mask[-1] = 0.0 # Mask all but the first and last heads
head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
if isinstance(model, BertModel):
@ -333,6 +333,25 @@ class BertModelTest(unittest.TestCase):
len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[1].nonzero()),
multihead_outputs[1].numel())
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,