Fix a Bug, trainer_seq2seq.py, in the else branch at Line 172, generation_inputs should be a dict (#14546)

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()
This commit is contained in:
TranSirius 2021-12-08 01:09:18 +08:00 committed by GitHub
parent 2171695cc2
commit 39f1dff5a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -169,7 +169,7 @@ class Seq2SeqTrainer(Trainer):
# very ugly hack to make it work
generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
else:
generation_inputs = inputs["input_ids"]
generation_inputs = {"input_ids": inputs["input_ids"]}
generated_tokens = self.model.generate(
**generation_inputs,