From 96c4d3d9885a09340a10869949c7c9bea4bfb5c4 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Mon, 17 Jun 2019 12:17:26 +0200
Subject: [PATCH] add head masking tests

---
 pytorch_pretrained_bert/modeling.py | 128 ++++++++++++++++++++++------
 tests/modeling_test.py              |  42 +++++++++
 2 files changed, 142 insertions(+), 28 deletions(-)

diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py
index 11a7191df5d..950f96744ce 100644
--- a/pytorch_pretrained_bert/modeling.py
+++ b/pytorch_pretrained_bert/modeling.py
@@ -51,6 +51,32 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
 BERT_CONFIG_NAME = 'bert_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'
 
+def prune_linear_layer(layer, index, dim=0):
+    """ Prune a linear layer (a model parameter) to keep only entries in index.
+        dim=0 prunes output units (rows of the weight), dim=1 prunes input units (columns).
+        Return the pruned layer as a new layer with requires_grad=True.
+        Used to remove heads.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if layer.bias is not None:
+        if dim == 1:
+            b = layer.bias.clone().detach()
+        else:
+            b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    if layer.bias is not None:
+        new_layer.bias.requires_grad = False
+        new_layer.bias.copy_(b.contiguous())
+        new_layer.bias.requires_grad = True
+    return new_layer
+
+
 def load_tf_weights_in_bert(model, tf_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
     """
@@ -329,12 +355,7 @@ class BertSelfAttention(nn.Module):
         attention_probs = self.dropout(attention_probs)
 
         # Mask heads if we want to
-        # attention_probs has shape bsz x n_heads x N x N
         if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-            elif head_mask.dim() == 2:
-                head_mask.unsqueeze(-1).unsqueeze(-1)
             # We can define heads to mask for each instance in the batch
             attention_probs = attention_probs * head_mask
         context_layer = torch.matmul(attention_probs, value_layer)
@@ -365,12 +386,28 @@ class BertSelfOutput(nn.Module):
 
 
 class BertAttention(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertAttention, self).__init__()
         self.output_attentions = output_attentions
-        self.self = BertSelfAttention(config, output_attentions=output_attentions)
+        self.self = BertSelfAttention(config, output_attentions=output_attentions,
+                                      keep_multihead_output=keep_multihead_output)
         self.output = BertSelfOutput(config)
 
+    def prune_heads(self, heads):
+        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+        # Update hyper params
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_output = self.self(input_tensor, attention_mask, head_mask)
         if self.output_attentions:
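A minimal sketch (not part of the patch) of what `prune_linear_layer` does on a toy `nn.Linear`, assuming the patch above is applied so the helper is importable from `pytorch_pretrained_bert.modeling`: `dim=0` keeps a subset of output units (the query/key/value case in `prune_heads` above), while `dim=1` keeps a subset of input units (the attention output projection case).

```python
# Illustrative only: prune_linear_layer on a toy layer.
import torch
from pytorch_pretrained_bert.modeling import prune_linear_layer

layer = torch.nn.Linear(8, 8)          # weight shape: (out_features=8, in_features=8)
index = torch.tensor([0, 1, 2, 3])     # keep only the first four units

pruned_out = prune_linear_layer(layer, index, dim=0)   # prune output units
print(pruned_out.weight.shape)         # torch.Size([4, 8])

pruned_in = prune_linear_layer(layer, index, dim=1)    # prune input units
print(pruned_in.weight.shape)          # torch.Size([8, 4])
```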
@@ -411,10 +448,11 @@ class BertOutput(nn.Module):
 
 
 class BertLayer(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertLayer, self).__init__()
         self.output_attentions = output_attentions
-        self.attention = BertAttention(config, output_attentions=output_attentions)
+        self.attention = BertAttention(config, output_attentions=output_attentions,
+                                       keep_multihead_output=keep_multihead_output)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
@@ -430,10 +468,11 @@ class BertLayer(nn.Module):
 
 
 class BertEncoder(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertEncoder, self).__init__()
         self.output_attentions = output_attentions
-        layer = BertLayer(config, output_attentions=output_attentions)
+        layer = BertLayer(config, output_attentions=output_attentions,
+                          keep_multihead_output=keep_multihead_output)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
@@ -741,14 +780,28 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertModel, self).__init__(config)
         self.output_attentions = output_attentions
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config, output_attentions=output_attentions)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions,
+                                   keep_multihead_output=keep_multihead_output)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
 
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [layer.attention.self.multihead_output for layer in self.encoder.layer]
+
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
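The two methods above give `BertModel` a user-facing pruning API. Below is a minimal sketch (not part of the patch) of how it would be called; the `BertConfig` values are illustrative, not taken from the patch.

```python
# Illustrative only: removing heads from a small randomly-initialized BertModel.
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = BertModel(config)
model.eval()

# Remove heads 0 and 1 in layer 0, and head 2 in layer 1.
model.prune_heads({0: [0, 1], 1: [2]})

input_ids = torch.randint(0, 99, (1, 7))
encoded_layers, pooled_output = model(input_ids, output_all_encoded_layers=False)
print(encoded_layers.shape)  # torch.Size([1, 7, 32]) - hidden size unchanged, heads removed internally
```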
@@ -770,6 +823,17 @@ class BertModel(BertPreTrainedModel):
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we need to mask the head
+        # attention_probs has shape bsz x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # We can specify a head_mask for each instance in the batch
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
+            head_mask = (1.0 - head_mask)
+
         embedding_output = self.embeddings(input_ids, token_type_ids)
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
@@ -836,10 +900,11 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForPreTraining, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)
 
@@ -905,10 +970,11 @@ class BertForMaskedLM(BertPreTrainedModel):
     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForMaskedLM, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)
 
@@ -974,10 +1040,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForNextSentencePrediction, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)
 
@@ -1045,11 +1112,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
         super(BertForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)
@@ -1116,11 +1184,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2, output_attentions=False):
+    def __init__(self, config, num_choices=2, output_attentions=False, keep_multihead_output=False):
         super(BertForMultipleChoice, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_choices = num_choices
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_bert_weights)
@@ -1192,11 +1261,12 @@ class BertForTokenClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
         super(BertForTokenClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)
@@ -1273,14 +1343,16 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForQuestionAnswering, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_bert_weights)
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
+                end_positions=None, head_mask=None):
         outputs = self.bert(input_ids, token_type_ids, attention_mask,
                             output_all_encoded_layers=False,
                             head_mask=head_mask)
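With the constructors threaded through as above, a head mask can be passed to any of the task-specific models (the new test below exercises exactly this). A minimal sketch (not part of the patch), again with illustrative `BertConfig` values; in this patch, 1.0 in `head_mask` marks a head to be masked.

```python
# Illustrative only: masking a head at inference time via the head_mask argument.
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = BertForSequenceClassification(config, num_labels=3)
model.eval()

input_ids = torch.randint(0, 99, (2, 7))
head_mask = torch.zeros(config.num_attention_heads)
head_mask[1] = 1.0  # mask head 1 in every layer; a 2-D mask of shape (batch_size, num_heads) masks per example
logits = model(input_ids, head_mask=head_mask)
print(logits.shape)  # torch.Size([2, 3])
```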
diff --git a/tests/modeling_test.py b/tests/modeling_test.py
index 79993ed8402..4c78ead7679 100644
--- a/tests/modeling_test.py
+++ b/tests/modeling_test.py
@@ -293,6 +293,47 @@ class BertModelTest(unittest.TestCase):
                     [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
 
+        def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                                BertForTokenClassification):
+                if model_class in [BertForSequenceClassification,
+                                   BertForTokenClassification]:
+                    model = model_class(config=config,
+                                        num_labels=self.num_labels,
+                                        keep_multihead_output=True)
+                else:
+                    model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.ones(self.num_attention_heads).to(input_ids.device)
+                head_mask[0] = 0.0
+                head_mask[-1] = 0.0  # Mask all but the first and last heads
+                output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
+
+                if isinstance(model, BertModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+
+
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
@@ -352,6 +393,7 @@ class BertModelTest(unittest.TestCase):
         tester.check_loss_output(output_result)
 
         tester.create_and_check_bert_for_attentions(*config_and_inputs)
+        tester.create_and_check_bert_for_headmasking(*config_and_inputs)
 
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
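As a companion to the new test, here is a minimal sketch (not part of the patch) of how `keep_multihead_output` and `get_multihead_outputs` would be used to inspect per-head activations. It assumes the rest of the branch stores the per-head context tensors and retains their gradients, as the `get_multihead_outputs` docstring above indicates; the `BertConfig` values are illustrative.

```python
# Illustrative only: collecting per-head outputs (and their gradients, if retained by the branch).
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = BertModel(config, keep_multihead_output=True)
model.eval()

input_ids = torch.randint(0, 99, (1, 7))
head_mask = torch.ones(config.num_attention_heads)
head_mask[0] = 0.0  # keep head 0, mask the others (1.0 = mask in this patch)

encoded_layers, pooled_output = model(input_ids, head_mask=head_mask, output_all_encoded_layers=False)
encoded_layers.sum().backward()

multihead_outputs = model.get_multihead_outputs()  # one tensor per layer
print(multihead_outputs[0].shape)                  # torch.Size([1, 4, 7, 8]): batch x heads x seq_len x head_size
if multihead_outputs[0].grad is not None:          # gradients are expected per the docstring above
    print(multihead_outputs[0].grad.shape)
```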