mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-16 11:08:23 +06:00
[EncoderDecoder] Fix Typo (#7915)
* fix encoder decoder models * add .gitignore
This commit is contained in:
parent
55bcd0cb59
commit
c912ba5f69
3
.gitignore
vendored
3
.gitignore
vendored
@ -157,3 +157,6 @@ debug.env
|
|||||||
|
|
||||||
#ctags
|
#ctags
|
||||||
tags
|
tags
|
||||||
|
|
||||||
|
# pre-commit
|
||||||
|
.pre-commit*
|
||||||
|
@ -434,8 +434,6 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
encoder_attentions=encoder_outputs.attentions,
|
encoder_attentions=encoder_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
||||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
|
Loading…
Reference in New Issue
Block a user