Remove and do the branching in

This commit is contained in:
Rémi Louf 2019-10-10 10:17:27 +02:00
parent 09cfd12235
commit edfc8f8225

View File

@ -282,53 +282,13 @@ class BertAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertDecoderAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertGeneralAttention(config)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, query, key, value, attention_mask=None, head_mask=None):
self_outputs = self.self(query, key, value, attention_mask, head_mask)
# in encoder-decoder attention we use the output of the previous decoder stage as the query
# in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput.
# This shows the limits of the current code architecture, which may benefit from some refactoring.
attention_output = self.output(self_outputs[0], query)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()