Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
generalize BertSelfAttention to take separate query, key, value
There is currently no way to specify the query, key and value separately in the Attention module. However, the decoder's "encoder-decoder attention" layers take the decoder's last output as the query and the encoder's states as key and value. We thus modify the existing code so that query, key and value can be passed separately. This obviously strains some naming conventions: `BertSelfAttention` is no longer strictly a self-attention module, the way the residual is forwarded is now awkward, etc. We will need to do some refactoring once the decoder is fully implemented.
This commit is contained in:
parent 31adbb247c
commit a0dcefa382
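To make the intent of the diff below concrete, here is a minimal, self-contained sketch (class and tensor names are illustrative, not the repository's): a multi-head attention forward that accepts query, key and value separately can serve both classic self-attention (all three inputs are the same tensor) and the decoder's encoder-decoder attention (query from the decoder, key and value from the encoder).

import math
import torch
import torch.nn as nn

class MiniAttention(nn.Module):
    # Illustrative stand-in for the generalized BertSelfAttention, not the repository's class.
    def __init__(self, hidden_size=64, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def transpose_for_scores(self, x):
        # (batch, seq, hidden) -> (batch, heads, seq, head_size)
        b, s, _ = x.shape
        return x.view(b, s, self.num_heads, self.head_size).permute(0, 2, 1, 3)

    def forward(self, query, key, value):
        q = self.transpose_for_scores(self.query(query))
        k = self.transpose_for_scores(self.key(key))
        v = self.transpose_for_scores(self.value(value))
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_size)
        probs = scores.softmax(dim=-1)
        context = torch.matmul(probs, v)                     # (batch, heads, q_len, head_size)
        return context.permute(0, 2, 1, 3).reshape(query.shape[0], query.shape[1], -1)

attn = MiniAttention()
encoder_states = torch.randn(2, 10, 64)   # (batch, src_len, hidden)
decoder_states = torch.randn(2, 7, 64)    # (batch, tgt_len, hidden)

self_attn = attn(decoder_states, decoder_states, decoder_states)   # classic self-attention
cross_attn = attn(decoder_states, encoder_states, encoder_states)  # encoder-decoder attention
print(self_attn.shape, cross_attn.shape)  # torch.Size([2, 7, 64]) torch.Size([2, 7, 64])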
@@ -198,10 +198,10 @@ class BertSelfAttention(nn.Module):
         x = x.view(*new_x_shape)
         return x.permute(0, 2, 1, 3)

-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
+    def forward(self, query, key, value, attention_mask=None, head_mask=None):
+        mixed_query_layer = self.query(query)
+        mixed_key_layer = self.key(key)
+        mixed_value_layer = self.value(value)

         query_layer = self.transpose_for_scores(mixed_query_layer)
         key_layer = self.transpose_for_scores(mixed_key_layer)
@@ -279,9 +279,12 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)

-    def forward(self, input_tensor, attention_mask=None, head_mask=None):
-        self_outputs = self.self(input_tensor, attention_mask, head_mask)
-        attention_output = self.output(self_outputs[0], input_tensor)
+    def forward(self, query_tensor, key_tensor, value_tensor, attention_mask=None, head_mask=None):
+        self_outputs = self.self(query_tensor, key_tensor, value_tensor, attention_mask, head_mask)
+        # in encoder-decoder attention we use the output of the previous decoder stage as the query
+        # in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput.
+        # This shows the limits of the current code architecture, which may benefit from some refactoring.
+        attention_output = self.output(self_outputs[0], query_tensor)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
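The comments added in this hunk deserve unpacking: the output block applies a residual connection, and once query, key and value differ, only the query tensor has the right sequence length to serve as that residual. A standalone sketch under assumed names (ResidualOutput is a stand-in for the repository's output block, not its actual class):

import torch
import torch.nn as nn

class ResidualOutput(nn.Module):
    # Illustrative: projects the attention context and adds a residual, like a BERT output block.
    def __init__(self, hidden_size=64):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, context, residual):
        return self.norm(self.dense(context) + residual)

decoder_query = torch.randn(2, 7, 64)   # decoder states: (batch, tgt_len, hidden)
encoder_keys = torch.randn(2, 10, 64)   # encoder states: (batch, src_len, hidden)
context = torch.randn(2, 7, 64)         # cross-attention output has the query's length, not the key's

out = ResidualOutput()(context, decoder_query)   # shapes match: (2, 7, 64)
# ResidualOutput()(context, encoder_keys) would fail: (2, 7, 64) + (2, 10, 64) cannot broadcast,
# which is why query_tensor is the tensor passed as the residual above.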
@@ -323,7 +326,11 @@ class BertEncoderLayer(nn.Module):
         self.output = BertOutput(config)

     def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        attention_outputs = self.attention(query_tensor=hidden_states,
+                                           key_tensor=hidden_states,
+                                           value_tensor=hidden_states,
+                                           attention_mask=attention_mask,
+                                           head_mask=head_mask)
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
@@ -333,6 +340,7 @@ class BertEncoderLayer(nn.Module):

 class BertDecoderLayer(nn.Module):
     def __init__(self, config):
         super(BertDecoderLayer, self).__init__()
         raise NotImplementedError

+    def forward(self, hidden_state, encoder_output):
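The commit leaves BertDecoderLayer unimplemented, so the following is purely a hypothetical sketch of how a forward(hidden_state, encoder_output) could wire self-attention and encoder-decoder attention together. It uses torch.nn.MultiheadAttention rather than the modified BertAttention, and all layer names and sizes are assumptions, not the repository's eventual implementation.

import torch
import torch.nn as nn

class SketchDecoderLayer(nn.Module):
    # Hypothetical decoder layer: self-attention, then cross-attention, then a feed-forward block.
    def __init__(self, hidden_size=64, num_heads=4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                                 nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.norm3 = nn.LayerNorm(hidden_size)

    def forward(self, hidden_state, encoder_output):
        # 1) self-attention: query, key and value all come from the decoder state
        attn, _ = self.self_attn(hidden_state, hidden_state, hidden_state)
        hidden_state = self.norm1(hidden_state + attn)
        # 2) encoder-decoder attention: query from the decoder, key/value from the encoder
        attn, _ = self.cross_attn(hidden_state, encoder_output, encoder_output)
        hidden_state = self.norm2(hidden_state + attn)
        # 3) position-wise feed-forward with residual
        return self.norm3(hidden_state + self.ffn(hidden_state))

layer = SketchDecoderLayer()
decoder_states = torch.randn(2, 7, 64)
encoder_states = torch.randn(2, 10, 64)
print(layer(decoder_states, encoder_states).shape)  # torch.Size([2, 7, 64])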