Mirror of https://github.com/huggingface/transformers.git
Remove BertDecoderAttention and do the branching in BertSelfAttention
commit edfc8f8225
parent 09cfd12235
@@ -282,53 +282,13 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        self_outputs = self.self(hidden_states, attention_mask, head_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
 
 
-class BertDecoderAttention(nn.Module):
-    def __init__(self, config):
-        super(BertAttention, self).__init__()
-        self.self = BertGeneralAttention(config)
-        self.output = BertSelfOutput(config)
-        self.pruned_heads = set()
-
-    def prune_heads(self, heads):
-        if len(heads) == 0:
-            return
-        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
-        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
-
-        # Prune linear layers
-        self.self.query = prune_linear_layer(self.self.query, index)
-        self.self.key = prune_linear_layer(self.self.key, index)
-        self.self.value = prune_linear_layer(self.self.value, index)
-        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
-
-        # Update hyper params and store pruned heads
-        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
-        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
-        self.pruned_heads = self.pruned_heads.union(heads)
-
-    def forward(self, query, key, value, attention_mask=None, head_mask=None):
-        self_outputs = self.self(query, key, value, attention_mask, head_mask)
-        # in encoder-decoder attention we use the output of the previous decoder stage as the query
-        # in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput.
-        # This shows the limits of the current code architecture, which may benefit from some refactoring.
-        attention_output = self.output(self_outputs[0], query)
-        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
-        return outputs
-
-
 class BertIntermediate(nn.Module):
     def __init__(self, config):
         super(BertIntermediate, self).__init__()
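
The new encoder_hidden_states argument is simply forwarded to self.self, so the self-attention module becomes the place that decides between plain self-attention and encoder-decoder attention. Below is a minimal sketch of what that branching typically looks like; the class name SketchSelfAttention and its constructor arguments are illustrative assumptions, not code from this commit.

import torch
import torch.nn as nn


class SketchSelfAttention(nn.Module):
    # Illustrative only: shows where the encoder_hidden_states branching can live
    # once BertDecoderAttention is gone. Not the actual BertSelfAttention code.
    def __init__(self, hidden_size, num_attention_heads):
        super(SketchSelfAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

    def transpose_for_scores(self, x):
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*new_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
        # Queries always come from the decoder-side hidden states.
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        # The branching: keys/values come from the encoder output when it is given
        # (encoder-decoder attention), otherwise from hidden_states (self-attention).
        kv_source = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key_layer = self.transpose_for_scores(self.key(kv_source))
        value_layer = self.transpose_for_scores(self.value(kv_source))

        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / (self.attention_head_size ** 0.5)
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = nn.functional.softmax(scores, dim=-1)
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, value_layer).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)
        return (context,)

With the keys and values taken from the encoder output whenever it is provided, a single attention module covers both cases, which is what makes the separate BertDecoderAttention class above redundant.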
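
The prune_heads implementation deleted with the class is identical to the one that remains on BertAttention. Its one non-obvious step is the remapping of head indices when some heads were pruned earlier: the weight matrices have already shrunk, so a logical head index has to be shifted left by the number of previously pruned heads that sit before it. A self-contained sketch of that arithmetic (the helper name and the example sizes are made up for illustration):

import torch


def head_indices_to_keep(heads_to_prune, already_pruned, num_attention_heads, attention_head_size):
    # Mirrors the mask logic in prune_heads: build a per-head mask, zero the heads
    # being pruned, then flatten to the per-feature indices that survive.
    heads_to_prune = set(heads_to_prune) - already_pruned
    mask = torch.ones(num_attention_heads, attention_head_size)
    for head in heads_to_prune:
        # Heads already pruned no longer occupy rows, so shift the index left
        # by the number of previously pruned heads that came before this one.
        head = head - sum(1 if h < head else 0 for h in already_pruned)
        mask[head] = 0
    mask = mask.view(-1).contiguous().eq(1)
    return torch.arange(len(mask))[mask].long()


# Example with 4 remaining heads of size 2: head 0 was pruned earlier, so logical
# head 2 now lives at row 1, and its features (flat indices 2 and 3) are dropped.
index = head_indices_to_keep({2}, already_pruned={0}, num_attention_heads=4, attention_head_size=2)
print(index)  # tensor([0, 1, 4, 5, 6, 7])

The resulting index is what the prune_linear_layer calls in the diff use to keep only the surviving rows of the query/key/value projections, and the matching columns of the output projection (hence dim=1 there).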
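
The comment in the deleted forward about passing query_tensor as the residual refers to the add-and-norm step performed by the output block. The sketch below shows a BertSelfOutput-style block (an illustrative approximation, not the library code) to make the point concrete: the second argument is what gets added back before LayerNorm, so in encoder-decoder attention it must be the decoder-side query stream rather than the encoder states.

import torch
import torch.nn as nn


class SketchSelfOutput(nn.Module):
    # Sketch of a BertSelfOutput-style block: project, dropout, then add-and-norm.
    # The residual (input_tensor) must come from the same stream as the queries.
    def __init__(self, hidden_size, dropout_prob=0.1):
        super(SketchSelfOutput, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Residual connection: attention output plus the tensor the queries came from.
        return self.LayerNorm(hidden_states + input_tensor)

Once the branching lives inside the self-attention module, BertAttention.forward can keep calling self.output(self_outputs[0], hidden_states) unchanged, because hidden_states is the query-side input in both the self-attention and encoder-decoder cases.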