Mirror of https://github.com/huggingface/transformers.git
the decoder attends to the output of the encoder stack (last layer)
This commit is contained in:
parent 56e2ee4ead
commit f873a3edb2
@@ -288,8 +288,8 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
-        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
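The `encoder_hidden_states` / `encoder_attention_mask` arguments are what turn this attention block into cross-attention: when they are supplied, the wrapped attention is expected to build its keys and values from the encoder output instead of from `hidden_states`. A minimal, self-contained sketch of that pattern (illustrative only, not the `BertSelfAttention` implementation; the single-head simplification and all names below are assumptions):

import torch
import torch.nn as nn

class MiniCrossAttention(nn.Module):
    """Toy single-head attention that attends to encoder states when they are given."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.scale = hidden_size ** -0.5

    def forward(self, hidden_states, attention_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None):
        # Queries always come from the decoder-side hidden states.
        q = self.query(hidden_states)
        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder output,
            # and the encoder-side attention mask applies.
            kv_source, mask = encoder_hidden_states, encoder_attention_mask
        else:
            # Plain self-attention.
            kv_source, mask = hidden_states, attention_mask
        k, v = self.key(kv_source), self.value(kv_source)
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if mask is not None:
            scores = scores + mask  # additive mask, large negative values on padding
        return torch.matmul(scores.softmax(dim=-1), v)

# Usage: decoder states (batch, tgt_len, hidden) attending to encoder states (batch, src_len, hidden).
attn = MiniCrossAttention(hidden_size=16)
out = attn(torch.randn(2, 5, 16), encoder_hidden_states=torch.randn(2, 7, 16))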
@@ -334,13 +334,13 @@ class BertLayer(nn.Module):
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
         self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
         attention_output = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
 
-        if self.is_decoder and encoder_hidden_state is not None:
-            cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
+        if self.is_decoder and encoder_hidden_states is not None:
+            cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
 
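At the layer level the hunk above shows the decoder control flow: self-attention first, then (only for decoder layers, and only when encoder states are passed) cross-attention, then the feed-forward sub-layer, with the attention weights of both sub-layers accumulated in `outputs`. A hedged, self-contained sketch of that flow using `torch.nn.MultiheadAttention` in place of the real BERT sub-modules (residuals and layer norms omitted; every name below is an assumption):

import torch
import torch.nn as nn

class MiniDecoderLayer(nn.Module):
    """Toy decoder layer: self-attention -> optional cross-attention -> feed-forward."""

    def __init__(self, hidden_size=16, num_heads=4, is_decoder=True):
        super().__init__()
        self.is_decoder = is_decoder
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.crossattention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states, encoder_hidden_states=None):
        # 1) self-attention over the decoder sequence
        attn_out, self_weights = self.attention(hidden_states, hidden_states, hidden_states)
        outputs = (self_weights,)
        # 2) cross-attention: queries from the decoder, keys/values from the encoder
        if self.is_decoder and encoder_hidden_states is not None:
            attn_out, cross_weights = self.crossattention(
                attn_out, encoder_hidden_states, encoder_hidden_states)
            outputs = outputs + (cross_weights,)  # keep both sets of attention weights
        # 3) position-wise feed-forward; return (layer_output, self_weights[, cross_weights])
        return (self.ffn(attn_out),) + outputs

# Usage: decoder states (batch, tgt_len, hidden) attending to encoder output (batch, src_len, hidden).
layer = MiniDecoderLayer()
layer_output, self_w, cross_w = layer(torch.randn(2, 5, 16),
                                      encoder_hidden_states=torch.randn(2, 7, 16))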
@@ -364,8 +364,7 @@ class BertEncoder(nn.Module):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            encoder_hidden_state = encoder_hidden_states[i]
-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
+            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
             hidden_states = layer_outputs[0]
 
             if self.output_attentions:
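This hunk is the core of the commit: previously decoder layer `i` received `encoder_hidden_states[i]`, i.e. the states of the matching encoder layer; now every decoder layer receives the same tensor, the output of the full encoder stack. A toy version of the new loop, reusing the `MiniDecoderLayer` sketch above (masking and head-pruning arguments omitted; shapes are assumptions):

import torch

# The same tensor, the output of the encoder *stack*, is handed to every layer.
encoder_hidden_states = torch.randn(2, 7, 16)  # (batch, src_len, hidden)
hidden_states = torch.randn(2, 5, 16)          # (batch, tgt_len, hidden)

layers = [MiniDecoderLayer() for _ in range(3)]  # MiniDecoderLayer from the sketch above
all_attentions = ()
for layer_module in layers:
    # no per-layer indexing such as encoder_hidden_states[i] any more
    layer_outputs = layer_module(hidden_states, encoder_hidden_states=encoder_hidden_states)
    hidden_states = layer_outputs[0]
    all_attentions = all_attentions + layer_outputs[1:]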
@@ -165,7 +165,7 @@ class PreTrainedSeq2seq(nn.Module):
         encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[0]
+            encoder_hidden_states = encoder_outputs[0][-1]  # output of the encoder *stack*
         else:
             encoder_outputs = ()
 
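In the seq2seq wrapper, the change keeps only the last layer of the encoder stack before handing the states to the decoder, which matches the commit title. A hedged sketch of how that wiring might look end to end (the decoder call, its keyword argument, and the dummy shapes are assumptions, not shown in this hunk):

import torch

def seq2seq_forward(encoder, decoder, encoder_input_ids, decoder_input_ids,
                    encoder_hidden_states=None):
    # Run the encoder only when no cached encoder states are supplied.
    if encoder_hidden_states is None:
        encoder_outputs = encoder(encoder_input_ids)
        # encoder_outputs[0] is assumed to stack per-layer hidden states;
        # [-1] keeps the output of the encoder *stack* (its last layer).
        encoder_hidden_states = encoder_outputs[0][-1]
    # Every decoder layer then cross-attends to these states.
    return decoder(decoder_input_ids, encoder_hidden_states=encoder_hidden_states)

# Dummy stand-ins for illustration: the "encoder" returns a tuple whose first
# element stacks per-layer hidden states; the "decoder" just consumes the kwarg.
dummy_encoder = lambda ids: (torch.randn(4, 2, 7, 16),)  # (layers, batch, src_len, hidden)
dummy_decoder = lambda ids, encoder_hidden_states: encoder_hidden_states.mean()
out = seq2seq_forward(dummy_encoder, dummy_decoder,
                      encoder_input_ids=torch.zeros(2, 7, dtype=torch.long),
                      decoder_input_ids=torch.zeros(2, 5, dtype=torch.long))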