Include decoder_attention_mask in T5 model inputs (#22835)
commit 3b61d2890d
parent 91d6a593f1
@@ -1807,6 +1807,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
+        decoder_attention_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
@@ -1823,6 +1824,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
             "attention_mask": attention_mask,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
+            "decoder_attention_mask": decoder_attention_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,
         }
@@ -1774,6 +1774,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
+        decoder_attention_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
@@ -1790,6 +1791,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             "attention_mask": attention_mask,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
+            "decoder_attention_mask": decoder_attention_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,
         }
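For context, before this change a decoder_attention_mask passed to generate() fell into the **kwargs of prepare_inputs_for_generation and was silently dropped, so it never reached the model's forward pass. The sketch below illustrates the behavior the commit enables; it is not part of the diff, and the checkpoint name, prompt, and forced decoder tokens are illustrative assumptions.

# Minimal sketch (assumes the t5-small checkpoint; values are illustrative).
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello there", return_tensors="pt")

# Seed the decoder with two tokens, one of them padding, and mask the pad slot.
# Before this commit, the mask kwarg was swallowed by prepare_inputs_for_generation
# and never forwarded to forward(); after it, the mask is passed through on every
# decoding step (generate extends it by one position per generated token).
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id, tokenizer.pad_token_id]])
decoder_attention_mask = torch.tensor([[1, 0]])

outputs = model.generate(
    **inputs,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    max_new_tokens=10,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))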