mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
[examples/summarization] deal with max_length
and num_beams
(#21740)
* Override the decoding parameters of Seq2SeqTrainer * Fix quality * Fix max_length parameter * Fix quality * Remove redundant parameter max_length * Separate the preprocess of train and validation to use different max_target_length
This commit is contained in:
parent
9ddf4f4f03
commit
3c0ce60855
@ -639,6 +639,16 @@ def main():
|
||||
result["gen_len"] = np.mean(prediction_lens)
|
||||
return result
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
training_args.generation_num_beams = (
|
||||
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
@ -672,15 +682,9 @@ def main():
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
@ -690,9 +694,7 @@ def main():
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
predict_results = trainer.predict(
|
||||
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
|
||||
)
|
||||
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
|
||||
metrics = predict_results.metrics
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
|
@ -161,15 +161,6 @@ def parse_args():
|
||||
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=128,
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
|
||||
" sequences shorter will be padded if `--pad_to_max_lengh` is passed."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
type=int,
|
||||
@ -473,6 +464,9 @@ def main():
|
||||
f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
|
||||
if args.val_max_target_length is None:
|
||||
args.val_max_target_length = args.max_target_length
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = args.max_target_length
|
||||
padding = "max_length" if args.pad_to_max_length else False
|
||||
@ -497,7 +491,7 @@ def main():
|
||||
return model_inputs
|
||||
|
||||
with accelerator.main_process_first():
|
||||
processed_datasets = raw_datasets.map(
|
||||
train_dataset = raw_datasets["train"].map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
@ -506,8 +500,16 @@ def main():
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
train_dataset = processed_datasets["train"]
|
||||
eval_dataset = processed_datasets["validation"]
|
||||
# Temporarily set max_target_length for validation.
|
||||
max_target_length = args.val_max_target_length
|
||||
eval_dataset = raw_datasets["validation"].map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 1):
|
||||
@ -667,11 +669,9 @@ def main():
|
||||
break
|
||||
|
||||
model.eval()
|
||||
if args.val_max_target_length is None:
|
||||
args.val_max_target_length = args.max_target_length
|
||||
|
||||
gen_kwargs = {
|
||||
"max_length": args.val_max_target_length if args is not None else config.max_length,
|
||||
"max_length": args.val_max_target_length,
|
||||
"num_beams": args.num_beams,
|
||||
}
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
|
Loading…
Reference in New Issue
Block a user