the decoder attends to the output of the encoder stack (last layer)

This commit is contained in:
Rémi Louf 2019-10-17 15:21:46 +02:00
parent 56e2ee4ead
commit f873a3edb2
2 changed files with 7 additions and 8 deletions

View File

@ -288,8 +288,8 @@ class BertAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
@ -334,13 +334,13 @@ class BertLayer(nn.Module):
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if self.is_decoder and encoder_hidden_state is not None:
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
if self.is_decoder and encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
@ -364,8 +364,7 @@ class BertEncoder(nn.Module):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
encoder_hidden_state = encoder_hidden_states[i]
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
hidden_states = layer_outputs[0]
if self.output_attentions:

View File

@ -165,7 +165,7 @@ class PreTrainedSeq2seq(nn.Module):
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack*
else:
encoder_outputs = ()