[T5] Fix Cross Attention position bias (#4499)

* fix

* fix1
ZhuBaohe 2020-05-26 20:57:24 +08:00 committed by GitHub
parent 1d69028989
commit a163c9ca5b
2 changed files with 2 additions and 2 deletions

src/transformers/modeling_t5.py

@@ -745,7 +745,7 @@ class T5Stack(T5PreTrainedModel):
                 # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                 position_bias = layer_outputs[3 if self.output_attentions else 2]
                 if self.is_decoder and encoder_hidden_states is not None:
-                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
+                    encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3]
             # append next layer key value states
             present_key_value_states = present_key_value_states + (present_key_value_state,)

src/transformers/modeling_tf_t5.py

@@ -682,7 +682,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                 position_bias = layer_outputs[3 if self.output_attentions else 2]
                 if self.is_decoder and encoder_hidden_states is not None:
-                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
+                    encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3]
             # append next layer key value states
             present_key_value_states = present_key_value_states + (present_key_value_state,)
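
Both hunks make the same one-line change, and the reason follows from the layer_outputs tuple layout spelled out in the comment above each changed line. The sketch below is a minimal, hypothetical helper with placeholder names (not code from the library), assuming the key-value-states entry is present as the PyTorch comment describes; it shows why index 5, not 4, selects the cross-attention position bias once attention weights are included in the tuple.

# Illustrative sketch only: reproduces the indexing logic from the diff.
def pick_position_biases(layer_outputs, output_attentions):
    # With output_attentions=True the decoder layer returns:
    #   (hidden_states, present_key_value_state,
    #    self_attn_weights, self_attn_position_bias,
    #    cross_attn_weights, cross_attn_position_bias)
    # Without attentions the two weights entries are dropped:
    #   (hidden_states, present_key_value_state,
    #    self_attn_position_bias, cross_attn_position_bias)
    position_bias = layer_outputs[3 if output_attentions else 2]
    # The old index 4 selected the cross-attention *weights*; 5 is the bias.
    encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
    return position_bias, encoder_decoder_position_bias

# Placeholder strings stand in for tensors:
outs = ("hidden", "kv", "self_w", "self_bias", "cross_w", "cross_bias")
assert pick_position_biases(outs, True) == ("self_bias", "cross_bias")
outs_no_attn = ("hidden", "kv", "self_bias", "cross_bias")
assert pick_position_biases(outs_no_attn, False) == ("self_bias", "cross_bias")

In other words, with attention outputs enabled the old index 4 pointed at the cross-attention weights, so those weights were stored in encoder_decoder_position_bias and reused as a position bias by the subsequent layers; index 5 picks the actual cross-attention position bias.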