generalize BertSelfAttention to take separate query, key, value

There is currently no way to specify the query, key and value separately
in the Attention module. However, the decoder's "encoder-decoder
attention" layers take the decoder's last output as the query and the
encoder's states as key and value. We thus modify the existing code so
that query, key and value can be passed separately.
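
As a hedged illustration of the intended call pattern (the tensor shapes
and the import path below are assumptions about this branch, not part of
the commit): self-attention passes the same tensor for query, key and
value, while encoder-decoder attention passes the decoder states as the
query and the encoder states as key and value.

import torch
from transformers import BertConfig
from transformers.modeling_bert import BertAttention  # module path assumed for this branch

config = BertConfig()
attention = BertAttention(config)

encoder_states = torch.rand(2, 7, config.hidden_size)  # encoder output, acts as key/value
decoder_states = torch.rand(2, 5, config.hidden_size)  # previous decoder output, acts as query

# Self-attention: query, key and value are all the same tensor.
self_attn = attention(decoder_states, decoder_states, decoder_states)

# Encoder-decoder attention: decoder states as query, encoder states as key and value.
cross_attn = attention(decoder_states, encoder_states, encoder_states)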

This obviously raises some naming issues; `BertSelfAttention` is no
longer strictly a self-attention module. The way the residual is
forwarded is now awkward, etc. We will need to do some refactoring once
the decoder is fully implemented.
Rémi Louf 2019-10-07 17:53:58 +02:00
parent 31adbb247c
commit a0dcefa382

@@ -198,10 +198,10 @@ class BertSelfAttention(nn.Module):
         x = x.view(*new_x_shape)
         return x.permute(0, 2, 1, 3)

-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
+    def forward(self, query, key, value, attention_mask=None, head_mask=None):
+        mixed_query_layer = self.query(query)
+        mixed_key_layer = self.key(key)
+        mixed_value_layer = self.value(value)

         query_layer = self.transpose_for_scores(mixed_query_layer)
         key_layer = self.transpose_for_scores(mixed_key_layer)
@@ -279,9 +279,12 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)

-    def forward(self, input_tensor, attention_mask=None, head_mask=None):
-        self_outputs = self.self(input_tensor, attention_mask, head_mask)
-        attention_output = self.output(self_outputs[0], input_tensor)
+    def forward(self, query_tensor, key_tensor, value_tensor, attention_mask=None, head_mask=None):
+        self_outputs = self.self(query_tensor, key_tensor, value_tensor, attention_mask, head_mask)
+        # in encoder-decoder attention we use the output of the previous decoder stage as the query
+        # in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput.
+        # This shows the limits of the current code architecture, which may benefit from some refactoring.
+        attention_output = self.output(self_outputs[0], query_tensor)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
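
To make the comment about the residual concrete, here is a rough sketch
of what the output sub-layer does in BERT (dense projection, dropout,
residual add, LayerNorm); the class below is an illustrative paraphrase,
not code from this commit. Passing query_tensor as the second argument
means the residual connection is taken from the query stream, i.e. the
decoder side in encoder-decoder attention.

import torch.nn as nn

class OutputBlockSketch(nn.Module):
    """Illustrative stand-in for the BERT output sub-layer: dense -> dropout -> residual -> LayerNorm."""
    def __init__(self, hidden_size=768, dropout_prob=0.1):
        super(OutputBlockSketch, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # The residual is taken from `input_tensor`; after this commit that is
        # `query_tensor`, i.e. the decoder stream in encoder-decoder attention.
        return self.LayerNorm(hidden_states + input_tensor)
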
@@ -323,7 +326,11 @@ class BertEncoderLayer(nn.Module):
         self.output = BertOutput(config)

     def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        attention_outputs = self.attention(query_tensor=hidden_states,
+                                           key_tensor=hidden_states,
+                                           value_tensor=hidden_states,
+                                           attention_mask=attention_mask,
+                                           head_mask=head_mask)
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
@@ -333,6 +340,7 @@ class BertEncoderLayer(nn.Module):
 class BertDecoderLayer(nn.Module):
     def __init__(self, config):
         super(BertDecoderLayer, self).__init__()
+        raise NotImplementedError

     def forward(self, hidden_state, encoder_output):
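
For context, a speculative sketch of how a decoder layer could eventually
wire the generalized attention module: masked self-attention over the
decoder states, then encoder-decoder attention with the decoder states as
query and the encoder output as key and value, then the usual feed-forward
block. Everything below is an assumption about where this is heading, not
code from this commit (which deliberately raises NotImplementedError).

import torch.nn as nn
from transformers.modeling_bert import BertAttention, BertIntermediate, BertOutput  # paths assumed for this branch

class BertDecoderLayerSketch(nn.Module):
    """Hypothetical decoder layer built on the generalized BertAttention; not part of this commit."""
    def __init__(self, config):
        super(BertDecoderLayerSketch, self).__init__()
        self.self_attention = BertAttention(config)   # (masked) self-attention over decoder states
        self.cross_attention = BertAttention(config)  # encoder-decoder attention
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_state, encoder_output, attention_mask=None):
        # Self-attention: query, key and value all come from the decoder states.
        self_outputs = self.self_attention(hidden_state, hidden_state, hidden_state,
                                           attention_mask=attention_mask)
        attention_output = self_outputs[0]

        # Encoder-decoder attention: decoder states as query, encoder output as key and value.
        cross_outputs = self.cross_attention(attention_output, encoder_output, encoder_output)
        cross_output = cross_outputs[0]

        # Feed-forward block, mirroring BertEncoderLayer.
        intermediate_output = self.intermediate(cross_output)
        layer_output = self.output(intermediate_output, cross_output)
        return layer_output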