diff --git a/examples/seq2seq/run_seq2seq.py b/examples/seq2seq/run_seq2seq.py
index 3b4ae5ef03a..70a635cc50b 100755
--- a/examples/seq2seq/run_seq2seq.py
+++ b/examples/seq2seq/run_seq2seq.py
@@ -167,9 +167,22 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    max_test_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "value if set."
+        },
+    )
     source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
     target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
-    eval_beams: Optional[int] = field(default=None, metadata={"help": "Number of beams to use for evaluation."})
+    num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+            "which is used during ``evaluate`` and ``predict``."
+        },
+    )
     ignore_pad_token_for_loss: bool = field(
         default=True,
         metadata={
@@ -336,8 +349,13 @@ def main():
     # We need to tokenize inputs and targets.
     if training_args.do_train:
         column_names = datasets["train"].column_names
-    else:
+    elif training_args.do_eval:
         column_names = datasets["validation"].column_names
+    elif training_args.do_predict:
+        column_names = datasets["test"].column_names
+    else:
+        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
+        return
 
     # For translation we set the codes of our source and target languages (only useful for mBART, the others will
     # ignore those attributes).
@@ -440,6 +458,19 @@ def main():
             load_from_cache_file=not data_args.overwrite_cache,
         )
 
+    if training_args.do_predict:
+        max_target_length = data_args.val_max_target_length
+        test_dataset = datasets["test"]
+        if data_args.max_test_samples is not None:
+            test_dataset = test_dataset.select(range(data_args.max_test_samples))
+        test_dataset = test_dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+
     # Data collator
     label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     if data_args.pad_to_max_length:
@@ -523,7 +554,7 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        results = trainer.evaluate()
+        results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
         if trainer.is_world_process_zero():
@@ -533,6 +564,34 @@ def main():
                     logger.info(f"  {key} = {value}")
                     writer.write(f"{key} = {value}\n")
 
+    if training_args.do_predict:
+        logger.info("*** Test ***")
+
+        test_results = trainer.predict(
+            test_dataset,
+            metric_key_prefix="test",
+            max_length=data_args.val_max_target_length,
+            num_beams=data_args.num_beams,
+        )
+        test_metrics = test_results.metrics
+
+        output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
+        if trainer.is_world_process_zero():
+            with open(output_test_result_file, "w") as writer:
+                logger.info("***** Test results *****")
+                for key, value in sorted(test_metrics.items()):
+                    logger.info(f"  {key} = {value}")
+                    writer.write(f"{key} = {value}\n")
+
+        if training_args.predict_with_generate:
+            test_preds = tokenizer.batch_decode(
+                test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+            )
+            test_preds = [pred.strip() for pred in test_preds]
+            output_test_preds_file = os.path.join(training_args.output_dir, "test_preds_seq2seq.txt")
+            with open(output_test_preds_file, "w") as writer:
+                writer.write("\n".join(test_preds))
+
     return results
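
A minimal usage sketch for trying the new prediction path (not part of the patch): only `--do_predict` handling, `--num_beams`, and `--max_test_samples` come from this change; `--predict_with_generate` already exists in `Seq2SeqTrainingArguments`, and the model and dataset flags below are illustrative placeholders, not prescribed by this diff.

    # placeholders: any seq2seq checkpoint plus a dataset that provides a "test" split
    python examples/seq2seq/run_seq2seq.py \
        --model_name_or_path t5-small \
        --task translation_en_to_ro \
        --dataset_name wmt16 --dataset_config_name ro-en \
        --output_dir /tmp/seq2seq_predict \
        --do_predict --predict_with_generate \
        --num_beams 4 --max_test_samples 16

With these flags the script writes the test metrics to `test_results_seq2seq.txt` and, because generation is enabled, the decoded predictions to `test_preds_seq2seq.txt` under `--output_dir`.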