Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-23 22:38:58 +06:00
[examples/summarization] deal with max_length and num_beams (#21740)
* Override the decoding parameters of Seq2SeqTrainer
* Fix quality
* Fix max_length parameter
* Fix quality
* Remove redundant parameter max_length
* Separate the preprocess of train and validation to use different max_target_length
This commit is contained in:
parent 9ddf4f4f03
commit 3c0ce60855
examples/pytorch/summarization/run_summarization.py

@@ -639,6 +639,16 @@ def main():
         result["gen_len"] = np.mean(prediction_lens)
         return result
 
+    # Override the decoding parameters of Seq2SeqTrainer
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
+
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(
         model=model,
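The override gives an explicit --generation_max_length precedence over --val_max_target_length, and an explicit --num_beams precedence over the trainer's generation_num_beams. A minimal standalone sketch of that resolution order (plain values stand in for the real argument dataclasses):

    # Sketch of the precedence applied above; the inputs are hypothetical
    # plain values, not Seq2SeqTrainingArguments / DataTrainingArguments.
    def resolve_generation_params(generation_max_length, val_max_target_length, num_beams, generation_num_beams):
        max_length = generation_max_length if generation_max_length is not None else val_max_target_length
        beams = num_beams if num_beams is not None else generation_num_beams
        return max_length, beams

    assert resolve_generation_params(None, 128, None, 1) == (128, 1)  # falls back to the data args
    assert resolve_generation_params(64, 128, 4, 1) == (64, 4)        # explicit flags win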
@@ -672,15 +682,9 @@ def main():
 
     # Evaluation
     results = {}
-    max_length = (
-        training_args.generation_max_length
-        if training_args.generation_max_length is not None
-        else data_args.val_max_target_length
-    )
-    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
+        metrics = trainer.evaluate(metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
 
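The bare trainer.evaluate(metric_key_prefix="eval") still decodes with the intended settings because Seq2SeqTrainer falls back to the generation_* fields on its training arguments when no per-call values are given. A runnable sketch of that fallback (a paraphrase, not the actual trainer source; values are hypothetical):

    # Paraphrase of the Seq2SeqTrainer fallback: per-call arguments win,
    # otherwise the Seq2SeqTrainingArguments fields set above are used.
    class _ArgsSketch:
        generation_max_length = 142  # hypothetical resolved values
        generation_num_beams = 4

    def resolve(max_length=None, num_beams=None, args=_ArgsSketch):
        max_length = max_length if max_length is not None else args.generation_max_length
        num_beams = num_beams if num_beams is not None else args.generation_num_beams
        return max_length, num_beams

    assert resolve() == (142, 4)              # bare evaluate() picks up training_args
    assert resolve(max_length=64) == (64, 4)  # an explicit per-call value still wins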
@@ -690,9 +694,7 @@ def main():
     if training_args.do_predict:
         logger.info("*** Predict ***")
 
-        predict_results = trainer.predict(
-            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
-        )
+        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
         metrics = predict_results.metrics
         max_predict_samples = (
             data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
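trainer.predict() relies on the same fallback, so the command line is unchanged. A hedged invocation of the Trainer-based script (model, dataset, and output path are illustrative; the flag names are the script's own):

    python examples/pytorch/summarization/run_summarization.py \
        --model_name_or_path t5-small \
        --dataset_name cnn_dailymail --dataset_config_name "3.0.0" \
        --do_eval --do_predict --predict_with_generate \
        --val_max_target_length 128 --num_beams 4 \
        --output_dir /tmp/summarization  # illustrative output path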
examples/pytorch/summarization/run_summarization_no_trainer.py

@@ -161,15 +161,6 @@ def parse_args():
             "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
         ),
     )
-    parser.add_argument(
-        "--max_length",
-        type=int,
-        default=128,
-        help=(
-            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
-            " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
-        ),
-    )
     parser.add_argument(
         "--num_beams",
         type=int,
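With the redundant --max_length flag gone, generation length in the no_trainer script is governed by --val_max_target_length alone. A minimal argparse sketch of the surviving decoding flags (defaults and parse input here are illustrative, not the script's):

    # Minimal sketch: the flags that now control decoding in the
    # no_trainer script (illustrative defaults and parse input).
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--val_max_target_length", type=int, default=None)
    parser.add_argument("--num_beams", type=int, default=None)
    args = parser.parse_args(["--val_max_target_length", "142", "--num_beams", "4"])
    assert (args.val_max_target_length, args.num_beams) == (142, 4)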
@@ -473,6 +464,9 @@ def main():
             f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}"
         )
 
+    if args.val_max_target_length is None:
+        args.val_max_target_length = args.max_target_length
+
     # Temporarily set max_target_length for training.
     max_target_length = args.max_target_length
     padding = "max_length" if args.pad_to_max_length else False
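Hoisting this default above preprocessing matters: both the validation map below and the gen_kwargs in the evaluation loop read args.val_max_target_length, so it must be resolved before either runs. The defaulting itself is just (runnable sketch with illustrative values):

    # Sketch of the defaulting above (illustrative values).
    class _Args:
        max_target_length = 128
        val_max_target_length = None

    args = _Args()
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length
    assert args.val_max_target_length == 128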
@@ -497,7 +491,7 @@ def main():
         return model_inputs
 
     with accelerator.main_process_first():
-        processed_datasets = raw_datasets.map(
+        train_dataset = raw_datasets["train"].map(
             preprocess_function,
             batched=True,
             num_proc=args.preprocessing_num_workers,
@@ -506,8 +500,16 @@ def main():
             desc="Running tokenizer on dataset",
         )
 
-    train_dataset = processed_datasets["train"]
-    eval_dataset = processed_datasets["validation"]
+        # Temporarily set max_target_length for validation.
+        max_target_length = args.val_max_target_length
+        eval_dataset = raw_datasets["validation"].map(
+            preprocess_function,
+            batched=True,
+            num_proc=args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not args.overwrite_cache,
+            desc="Running tokenizer on dataset",
+        )
 
     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 1):
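The point of mapping the two splits separately: preprocess_function reads max_target_length at call time, so reassigning the variable between the two .map() calls tokenizes train targets with --max_target_length and validation targets with --val_max_target_length. A self-contained sketch of that late-binding behavior (toy truncation stands in for tokenization):

    # Toy demonstration: the function picks up whatever max_target_length
    # is bound to when it runs, not when it was defined.
    max_target_length = 128  # training length

    def preprocess(batch):
        return [text[:max_target_length] for text in batch]

    train_out = preprocess(["x" * 200])
    max_target_length = 142  # validation length
    eval_out = preprocess(["x" * 200])
    assert (len(train_out[0]), len(eval_out[0])) == (128, 142)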
@@ -667,11 +669,9 @@ def main():
             break
 
     model.eval()
-    if args.val_max_target_length is None:
-        args.val_max_target_length = args.max_target_length
 
     gen_kwargs = {
-        "max_length": args.val_max_target_length if args is not None else config.max_length,
+        "max_length": args.val_max_target_length,
         "num_beams": args.num_beams,
     }
     for step, batch in enumerate(eval_dataloader):
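The args-is-None guard was dead code (args always exists by this point), and the defaulting moved earlier guarantees val_max_target_length is set. For context, the loop below feeds these kwargs into generation roughly as follows (fragment in the shape used by the surrounding script; accelerator, model, and batch come from that script, so it is not runnable on its own):

    # How the evaluation loop consumes gen_kwargs (names from the
    # surrounding script; shown for context only).
    generated_tokens = accelerator.unwrap_model(model).generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        **gen_kwargs,
    )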