https://github.com/huggingface/transformers.git

add head masking tests

parent 34858ae1d9
commit 96c4d3d988
@@ -51,6 +51,32 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
 BERT_CONFIG_NAME = 'bert_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'

+def prune_linear_layer(layer, index, dim=-1):
+    """ Prune a linear layer (a model parameters) to keep only entries in index.
+        Return the pruned layer as a new layer with requires_grad=True.
+        Used to remove heads.
+    """
+    dim = (dim+100) % 2
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if layer.bias is not None:
+        if dim == 1:
+            b = layer.bias.clone().detach()
+        else:
+            b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    if layer.bias is not None:
+        new_layer.bias.requires_grad = False
+        new_layer.bias.copy_(b.contiguous())
+        new_layer.bias.requires_grad = True
+    return new_layer
+
+
 def load_tf_weights_in_bert(model, tf_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
     """
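A hedged usage sketch of the new helper (the layer sizes are illustrative, and it assumes `prune_linear_layer` is importable from `pytorch_pretrained_bert.modeling` after this change):

```python
import torch
from torch import nn
from pytorch_pretrained_bert.modeling import prune_linear_layer  # assumed import path

# Keep the first 704 output units of a 768 -> 768 projection.
layer = nn.Linear(768, 768)                      # weight shape: (out_features, in_features)
index = torch.arange(704)                        # row indices to keep
pruned = prune_linear_layer(layer, index, dim=0)

print(pruned.weight.shape)                       # torch.Size([704, 768])
print(pruned.bias.shape)                         # torch.Size([704])
```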
@@ -329,12 +355,7 @@ class BertSelfAttention(nn.Module):
         attention_probs = self.dropout(attention_probs)

         # Mask heads if we want to
-        # attention_probs has shape bsz x n_heads x N x N
         if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-            elif head_mask.dim() == 2:
-                head_mask.unsqueeze(-1).unsqueeze(-1) # We can define heads to mask for each instance in the batch
             attention_probs = attention_probs * head_mask

         context_layer = torch.matmul(attention_probs, value_layer)
@@ -365,12 +386,28 @@ class BertSelfOutput(nn.Module):


 class BertAttention(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertAttention, self).__init__()
         self.output_attentions = output_attentions
-        self.self = BertSelfAttention(config, output_attentions=output_attentions)
+        self.self = BertSelfAttention(config, output_attentions=output_attentions,
+                                      keep_multihead_output=keep_multihead_output)
         self.output = BertSelfOutput(config)

+    def prune_heads(self, heads):
+        mask = torch.ones(self.self.n_heads, self.self.d_head)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=0)
+        # Update hyper params
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_output = self.self(input_tensor, attention_mask, head_mask)
         if self.output_attentions:
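A small, hedged sketch of how `prune_heads` above maps attention-head indices to the weight-row index passed to `prune_linear_layer` (the head count and head size below are illustrative):

```python
import torch

n_heads, d_head = 12, 64
heads_to_prune = [0, 11]

mask = torch.ones(n_heads, d_head)
for head in heads_to_prune:
    mask[head] = 0                            # zero out every slot of a pruned head
mask = mask.view(-1).contiguous().eq(1)       # flat boolean mask over n_heads * d_head rows
index = torch.arange(len(mask))[mask].long()  # row indices to keep in query/key/value

print(index.shape)                            # torch.Size([640])  == (12 - 2) * 64
```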
@@ -411,10 +448,11 @@ class BertOutput(nn.Module):


 class BertLayer(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertLayer, self).__init__()
         self.output_attentions = output_attentions
-        self.attention = BertAttention(config, output_attentions=output_attentions)
+        self.attention = BertAttention(config, output_attentions=output_attentions,
+                                       keep_multihead_output=keep_multihead_output)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

@@ -430,10 +468,11 @@ class BertLayer(nn.Module):


 class BertEncoder(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertEncoder, self).__init__()
         self.output_attentions = output_attentions
-        layer = BertLayer(config, output_attentions=output_attentions)
+        layer = BertLayer(config, output_attentions=output_attentions,
+                          keep_multihead_output=keep_multihead_output)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
@@ -741,14 +780,28 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertModel, self).__init__(config)
         self.output_attentions = output_attentions
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config, output_attentions=output_attentions)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions,
+                                   keep_multihead_output=keep_multihead_output)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)

+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [layer.attention.self.multihead_output for layer in self.encoder.layer]
+
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
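A hedged usage sketch of the pruning API described in the docstring above (the config values and layer/head choices are made up, and it assumes the new code paths wire up as intended):

```python
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertModel(config)

# {layer_num: list of heads to prune in this layer}, as documented above.
model.prune_heads({0: [0, 1], 11: [2]})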
@@ -770,6 +823,17 @@ class BertModel(BertPreTrainedModel):
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

+        # Prepare head mask if needed
+        # 1 in head_mask indicate we need to mask the head
+        # attention_probs has shape bsz x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each instance in batch
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+            head_mask = (1.0 - head_mask)
+
         embedding_output = self.embeddings(input_ids, token_type_ids)
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
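A hedged sketch of the broadcasting done by the block above: a 1-D head mask is expanded so it multiplies the attention probabilities of every layer and batch instance, and the `1 - mask` flip turns "1 means mask this head" into a zero multiplier (sizes are illustrative):

```python
import torch

n_heads = 12
head_mask = torch.zeros(n_heads)
head_mask[3] = 1.0                                             # "1 means mask this head"

expanded = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # shape (1, n_heads, 1, 1)
multiplier = 1.0 - expanded                                    # masked head -> 0.0, others -> 1.0

attention_probs = torch.rand(8, n_heads, 128, 128)             # bsz x n_heads x N x N
attention_probs = attention_probs * multiplier                 # zeroes out head 3 everywhere
```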
@@ -836,10 +900,11 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForPreTraining, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

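The same constructor change repeats for the task-specific models below. As a hedged sketch of what the new flag enables (config and input sizes are illustrative; the retrieval pattern mirrors the new test at the end of this diff):

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForPreTraining(config, keep_multihead_output=True)
model.eval()

input_ids = torch.randint(0, 32000, (2, 16))      # (batch, seq_len)
prediction_scores, seq_relationship_score = model(input_ids)

# One tensor per layer, shaped (batch, n_heads, seq_len, head_dim).
per_head_outputs = model.bert.get_multihead_outputs()
```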
@@ -905,10 +970,11 @@ class BertForMaskedLM(BertPreTrainedModel):
     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForMaskedLM, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

@@ -974,10 +1040,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForNextSentencePrediction, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)

@@ -1045,11 +1112,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
         super(BertForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)
@@ -1116,11 +1184,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2, output_attentions=False):
+    def __init__(self, config, num_choices=2, output_attentions=False, keep_multihead_output=False):
         super(BertForMultipleChoice, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_choices = num_choices
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_bert_weights)
@@ -1192,11 +1261,12 @@ class BertForTokenClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
         super(BertForTokenClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)
@@ -1273,14 +1343,16 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(BertForQuestionAnswering, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config, output_attentions=output_attentions,
+                              keep_multihead_output=keep_multihead_output)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_bert_weights)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
+                end_positions=None, head_mask=None):
         outputs = self.bert(input_ids, token_type_ids, attention_mask,
                             output_all_encoded_layers=False,
                             head_mask=head_mask)
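A hedged call sketch for the `head_mask` argument threaded through the forward passes above: `model`, `input_ids`, `token_type_ids`, and `input_mask` are assumed to be set up as in the docstring examples, and per the convention noted earlier, 1 means "mask this head".

```python
import torch

n_attention_heads = 12                      # illustrative; matches config.num_attention_heads
head_mask = torch.zeros(n_attention_heads)
head_mask[5] = 1.0                          # mask head 5 in every layer

start_logits, end_logits = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
```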
@@ -293,6 +293,47 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])

+
+        def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                                BertForTokenClassification):
+                if model_class in [BertForSequenceClassification,
+                                   BertForTokenClassification]:
+                    model = model_class(config=config,
+                                        num_labels=self.num_labels,
+                                        keep_multihead_output=True)
+                else:
+                    model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.ones(self.num_attention_heads).to(input_ids.device)
+                head_mask[0] = 0.0
+                head_mask[-1] = 0.0 # Mask all but the first and last heads
+                output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
+
+                if isinstance(model, BertModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+

     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))

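A toy, hedged illustration of the nonzero-count checks used in the new test: a masked head contributes an all-zero slice of the per-head output, while a kept head contributes a fully nonzero slice of size `batch * seq_len * head_dim` (shapes below are made up):

```python
import torch

batch, n_heads, seq_len, head_dim = 2, 4, 3, 5
multihead_output = torch.rand(batch, n_heads, seq_len, head_dim) + 0.1  # strictly nonzero
multihead_output[:, 1:n_heads - 1] = 0.0        # pretend the middle heads were masked

assert len(multihead_output[:, 1:n_heads - 1].nonzero()) == 0
assert len(multihead_output[:, 0].nonzero()) == batch * seq_len * head_dim
assert len(multihead_output[:, n_heads - 1].nonzero()) == batch * seq_len * head_dim
```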
@@ -352,6 +393,7 @@ class BertModelTest(unittest.TestCase):
         tester.check_loss_output(output_result)

         tester.create_and_check_bert_for_attentions(*config_and_inputs)
+        tester.create_and_check_bert_for_headmasking(*config_and_inputs)

     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):