Rebase on master + DistilBERT head pruning patch

This commit is contained in:
LysandreJik 2019-08-31 00:37:41 -04:00
parent b6992b7b47
commit 11600edc6e

View File

@ -174,12 +174,16 @@ class MultiHeadSelfAttention(nn.Module):
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.pruned_heads = set()
def prune_heads(self, heads):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
heads = set(heads) - self.pruned_heads
for head in heads:
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
@ -191,6 +195,7 @@ class MultiHeadSelfAttention(nn.Module):
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.dim = attention_head_size * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, query, key, value, mask, head_mask = None):
"""
@ -395,7 +400,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
def __init__(self, *inputs, **kwargs):
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
def _init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, nn.Embedding):
@ -480,7 +485,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self.apply(self.init_weights)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
@ -568,7 +573,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
@ -642,7 +647,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
self.classifier = nn.Linear(config.dim, config.num_labels)
self.dropout = nn.Dropout(config.seq_classif_dropout)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
distilbert_output = self.distilbert(input_ids=input_ids,
@ -716,7 +721,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
assert config.num_labels == 2
self.dropout = nn.Dropout(config.qa_dropout)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
distilbert_output = self.distilbert(input_ids=input_ids,