[Pix2Struct] Fix pix2struct cross attention (#25200)

* fix pix2struct cross attention

* fix torchscript slow test
This commit is contained in:
Younes Belkada 2023-08-01 10:56:37 +02:00 committed by GitHub
parent 4033ea7167
commit 77c3973e8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1547,8 +1547,9 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],)
all_cross_attentions = all_cross_attentions + (layer_outputs[3],)
all_attentions = all_attentions + (layer_outputs[3],)
if encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)