diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py
index 6abb41b33fe..bdf82bda9f3 100644
--- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py
+++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py
@@ -46,12 +46,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         **gen_kwargs,
     ) -> Dict[str, float]:
         gen_kwargs = gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
-        )
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if gen_kwargs.get("max_length") is None and self.args.generation_max_length is not None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
         self._gen_kwargs = gen_kwargs
 
         eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 569d939fdcf..bfba0d59436 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -149,11 +149,17 @@ class Seq2SeqTrainer(Trainer):
         """
 
         gen_kwargs = gen_kwargs.copy()
-        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
             gen_kwargs["max_length"] = self.args.generation_max_length
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
-        )
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
         self._gen_kwargs = gen_kwargs
 
         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
@@ -206,11 +212,17 @@ class Seq2SeqTrainer(Trainer):
         """
 
         gen_kwargs = gen_kwargs.copy()
-        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
             gen_kwargs["max_length"] = self.args.generation_max_length
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
-        )
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
         self._gen_kwargs = gen_kwargs
 
         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
@@ -256,16 +268,14 @@ class Seq2SeqTrainer(Trainer):
         # XXX: adapt synced_gpus for fairscale as well
 
         # Priority (handled in generate):
-        # gen_kwargs > model.generation_config > default GenerationConfig()
-
+        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
         if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
             gen_kwargs = self._gen_kwargs.copy()
+            if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
+                gen_kwargs.pop("num_beams")
+            if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
+                gen_kwargs.pop("max_length")
 
-        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
-            gen_kwargs["max_length"] = self.model.config.max_length
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
-        )
         default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
         gen_kwargs["synced_gpus"] = (
             gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
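
For reference, a minimal usage sketch of the resulting precedence (not part of the patch; `model`, `tokenizer`, `eval_dataset`, and `compute_metrics` are assumed to exist already, and the argument values are purely illustrative):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hypothetical setup: only the generation-argument handling is relevant here.
args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,
    generation_max_length=128,  # legacy fallback, applied only when no max_length/max_new_tokens is passed
    generation_num_beams=4,     # legacy fallback, applied only when no num_beams is passed
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# 1) No generation kwargs passed: the training-args fallbacks apply (max_length=128, num_beams=4).
metrics = trainer.evaluate()

# 2) An explicit kwarg takes precedence: max_new_tokens is forwarded and generation_max_length is
#    ignored, while num_beams still falls back to generation_num_beams=4. Options that are neither
#    passed nor set in the training args are no longer injected, so generate() resolves them from
#    model.generation_config instead of the old model.config defaults.
metrics = trainer.evaluate(max_new_tokens=64)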