[examples/summarization] deal with max_length and num_beams (#21740)

* Override the decoding parameters of Seq2SeqTrainer

* Fix quality

* Fix max_length parameter

* Fix quality

* Remove redundant parameter max_length

* Separate the preprocess of train and validation to use different max_target_length
This commit is contained in:
bofeng huang 2023-02-27 08:18:14 +01:00 committed by GitHub
parent 9ddf4f4f03
commit 3c0ce60855
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 25 deletions

View File

@ -639,6 +639,16 @@ def main():
result["gen_len"] = np.mean(prediction_lens)
return result
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
training_args.generation_num_beams = (
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
)
# Initialize our Trainer
trainer = Seq2SeqTrainer(
model=model,
@ -672,15 +682,9 @@ def main():
# Evaluation
results = {}
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
metrics = trainer.evaluate(metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@ -690,9 +694,7 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")
predict_results = trainer.predict(
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
)
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
metrics = predict_results.metrics
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)

View File

@ -161,15 +161,6 @@ def parse_args():
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
),
)
parser.add_argument(
"--max_length",
type=int,
default=128,
help=(
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
" sequences shorter will be padded if `--pad_to_max_lengh` is passed."
),
)
parser.add_argument(
"--num_beams",
type=int,
@ -473,6 +464,9 @@ def main():
f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}"
)
if args.val_max_target_length is None:
args.val_max_target_length = args.max_target_length
# Temporarily set max_target_length for training.
max_target_length = args.max_target_length
padding = "max_length" if args.pad_to_max_length else False
@ -497,7 +491,7 @@ def main():
return model_inputs
with accelerator.main_process_first():
processed_datasets = raw_datasets.map(
train_dataset = raw_datasets["train"].map(
preprocess_function,
batched=True,
num_proc=args.preprocessing_num_workers,
@ -506,8 +500,16 @@ def main():
desc="Running tokenizer on dataset",
)
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]
# Temporarily set max_target_length for validation.
max_target_length = args.val_max_target_length
eval_dataset = raw_datasets["validation"].map(
preprocess_function,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
desc="Running tokenizer on dataset",
)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 1):
@ -667,11 +669,9 @@ def main():
break
model.eval()
if args.val_max_target_length is None:
args.val_max_target_length = args.max_target_length
gen_kwargs = {
"max_length": args.val_max_target_length if args is not None else config.max_length,
"max_length": args.val_max_target_length,
"num_beams": args.num_beams,
}
for step, batch in enumerate(eval_dataloader):