mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix Seq2seqTrainer decoder attention mask (#26841)
Don't drop decoder_input_ids without also dropping decoder_attention_mask
This commit is contained in:
parent
280c757f6c
commit
34678db4a1
@@ -288,7 +288,9 @@ class Seq2SeqTrainer(Trainer):
             and "decoder_input_ids" in generation_inputs
             and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
         ):
-            generation_inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
+            generation_inputs = {
+                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
+            }
         generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)

         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
Loading…
Reference in New Issue
Block a user