diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index e0b7fc214ec..2645f677f1c 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -639,6 +639,16 @@ def main():
         result["gen_len"] = np.mean(prediction_lens)
         return result
 
+    # Override the decoding parameters of Seq2SeqTrainer
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
+
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(
         model=model,
@@ -672,15 +682,9 @@ def main():
 
     # Evaluation
     results = {}
-    max_length = (
-        training_args.generation_max_length
-        if training_args.generation_max_length is not None
-        else data_args.val_max_target_length
-    )
-    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
+        metrics = trainer.evaluate(metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
 
@@ -690,9 +694,7 @@ def main():
     if training_args.do_predict:
         logger.info("*** Predict ***")
 
-        predict_results = trainer.predict(
-            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
-        )
+        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
         metrics = predict_results.metrics
         max_predict_samples = (
             data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py
index 8f669be72c5..b16a3fd0690 100644
--- a/examples/pytorch/summarization/run_summarization_no_trainer.py
+++ b/examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -161,15 +161,6 @@ def parse_args():
             "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
         ),
     )
-    parser.add_argument(
-        "--max_length",
-        type=int,
-        default=128,
-        help=(
-            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
-            " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
-        ),
-    )
     parser.add_argument(
         "--num_beams",
         type=int,
@@ -473,6 +464,9 @@ def main():
             f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}"
        )
 
+    if args.val_max_target_length is None:
+        args.val_max_target_length = args.max_target_length
+
     # Temporarily set max_target_length for training.
     max_target_length = args.max_target_length
     padding = "max_length" if args.pad_to_max_length else False
@@ -497,7 +491,7 @@ def main():
         return model_inputs
 
     with accelerator.main_process_first():
-        processed_datasets = raw_datasets.map(
+        train_dataset = raw_datasets["train"].map(
             preprocess_function,
             batched=True,
             num_proc=args.preprocessing_num_workers,
@@ -506,8 +500,16 @@ def main():
             desc="Running tokenizer on dataset",
         )
 
-    train_dataset = processed_datasets["train"]
-    eval_dataset = processed_datasets["validation"]
+        # Temporarily set max_target_length for validation.
+        max_target_length = args.val_max_target_length
+        eval_dataset = raw_datasets["validation"].map(
+            preprocess_function,
+            batched=True,
+            num_proc=args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not args.overwrite_cache,
+            desc="Running tokenizer on dataset",
+        )
 
     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 1):
@@ -667,11 +669,9 @@ def main():
                 break
 
         model.eval()
-        if args.val_max_target_length is None:
-            args.val_max_target_length = args.max_target_length
 
         gen_kwargs = {
-            "max_length": args.val_max_target_length if args is not None else config.max_length,
+            "max_length": args.val_max_target_length,
             "num_beams": args.num_beams,
         }
         for step, batch in enumerate(eval_dataloader):
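
Note on the first file: the explicit `max_length=`/`num_beams=` kwargs can be dropped because, as I understand the Trainer, `Seq2SeqTrainer` falls back to `args.generation_max_length`/`args.generation_num_beams` (and ultimately the model config's generation defaults) whenever no per-call values are passed to `evaluate`/`predict`. Here is a minimal sketch of that precedence; `TrainerSketch` and `ConfigStub` are illustrative stand-ins, not the real Trainer internals:

```python
from typing import Optional


class ConfigStub:
    """Stand-in for model.config generation defaults (illustrative values)."""

    max_length = 20
    num_beams = 1


class TrainerSketch:
    """Simplified stand-in for Seq2SeqTrainer's generation-kwarg resolution."""

    def __init__(self, generation_max_length: Optional[int] = None, generation_num_beams: Optional[int] = None):
        self.generation_max_length = generation_max_length
        self.generation_num_beams = generation_num_beams
        self.config = ConfigStub()

    def evaluate(self, max_length: Optional[int] = None, num_beams: Optional[int] = None) -> dict:
        # Precedence: per-call kwarg > training-args value > model-config default.
        resolved_max_length = (
            max_length
            if max_length is not None
            else self.generation_max_length if self.generation_max_length is not None else self.config.max_length
        )
        resolved_num_beams = (
            num_beams
            if num_beams is not None
            else self.generation_num_beams if self.generation_num_beams is not None else self.config.num_beams
        )
        return {"max_length": resolved_max_length, "num_beams": resolved_num_beams}


# After the patch, main() writes the resolved values onto training_args once,
# so the bare trainer.evaluate() resolves to the same generation kwargs as the
# old explicit trainer.evaluate(max_length=..., num_beams=...) call did.
trainer = TrainerSketch(generation_max_length=142, generation_num_beams=4)
assert trainer.evaluate() == trainer.evaluate(max_length=142, num_beams=4)
```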
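
Note on the second file: splitting the single `raw_datasets.map(...)` into per-split calls works because `preprocess_function` closes over the local variable `max_target_length` and reads its current value each time it is called, so reassigning the variable between the two `.map()` calls changes the label truncation length for the validation split only. A toy sketch of that late-binding behavior (simple string truncation standing in for the tokenizer):

```python
# Late binding in closures: preprocess reads the value of max_target_length
# that is in effect when it is *called*, not when it was defined.
max_target_length = 128  # set for training


def preprocess(texts):
    # stand-in for the label truncation inside preprocess_function
    return [t[:max_target_length] for t in texts]


train = preprocess(["x" * 200])  # labels truncated to 128

max_target_length = 64  # temporarily set for validation
val = preprocess(["x" * 200])  # labels truncated to 64

assert len(train[0]) == 128 and len(val[0]) == 64
```

This is also why the original single `map` over all splits could not give training and validation different target lengths: one call means one value of `max_target_length` in effect.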