From e16cbe88ae502787512f4d79645cac2f919bada9 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 13 Mar 2023 19:00:25 +0000
Subject: [PATCH] Trainer: let generate pick its inputs (#22108)

* Let generate pick its inputs

* fix squad seq2seq example
---
 src/transformers/trainer_seq2seq.py | 20 ++++----------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 4a79516d265..3f7fb821181 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -182,23 +182,11 @@ class Seq2SeqTrainer(Trainer):
             gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
         )
 
-        if "attention_mask" in inputs:
-            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
-        if "global_attention_mask" in inputs:
-            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
+        # TODO (Joao): the following line is needed to keep a consistent result on SQUAD. Ideally, we should not block
+        # users from preparing a dataset with `decoder_input_ids`.
+        inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
+        generated_tokens = self.model.generate(**inputs, **gen_kwargs)
 
-        # prepare generation inputs
-        # some encoder-decoder models can have varying encoder's and thus
-        # varying model input names
-        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
-            generation_inputs = inputs[self.model.encoder.main_input_name]
-        else:
-            generation_inputs = inputs[self.model.main_input_name]
-
-        generated_tokens = self.model.generate(
-            generation_inputs,
-            **gen_kwargs,
-        )
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
         # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
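
As a minimal sketch of the call pattern this patch adopts (not part of the patch itself): instead of hand-picking the encoder input tensor and copying attention masks into `gen_kwargs`, the whole inputs dict, minus `decoder_input_ids`, is forwarded to `generate`, which picks the tensors it needs. The checkpoint name, prompt text, and `max_new_tokens` value below are illustrative assumptions, not taken from the diff.

    # Sketch only: mirrors the patched Seq2SeqTrainer call pattern outside the Trainer.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")      # assumed checkpoint
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # The tokenizer output is a dict-like object with `input_ids` and `attention_mask`.
    inputs = tokenizer(["translate English to German: Hello world"], return_tensors="pt")

    # Mirror the patch: drop `decoder_input_ids` if the data pipeline produced them,
    # since `generate` builds its own decoder inputs.
    inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}

    # `generate` selects what it needs from the keyword arguments.
    generated_tokens = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))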