Include decoder_attention_mask in T5 model inputs (#22835)

This commit is contained in:
Aashiq Muhamed 2023-04-20 07:05:36 -07:00 committed by GitHub
parent 91d6a593f1
commit 3b61d2890d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 0 deletions

View File

@@ -1807,6 +1807,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
@@ -1823,6 +1824,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

View File

@@ -1774,6 +1774,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
@@ -1790,6 +1791,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}