diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py
index 02c57f45026..dcbfe916394 100755
--- a/examples/seq2seq/finetune_trainer.py
+++ b/examples/seq2seq/finetune_trainer.py
@@ -93,19 +93,27 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-    max_length: Optional[int] = field(
+    max_target_length: Optional[int] = field(
         default=128,
         metadata={
             "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-    eval_max_length: Optional[int] = field(
+    val_max_target_length: Optional[int] = field(
         default=142,
         metadata={
             "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded. "
+            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+            "during ``evaluate`` and ``predict``."
+        },
+    )
+    test_max_target_length: Optional[int] = field(
+        default=142,
+        metadata={
+            "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded."
-            " This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
         },
     )
     n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
@@ -233,7 +241,7 @@ def main():
             type_path="train",
             data_dir=data_args.data_dir,
             n_obs=data_args.n_train,
-            max_target_length=data_args.max_length,
+            max_target_length=data_args.max_target_length,
             max_source_length=data_args.max_source_length,
             prefix=model.config.prefix or "",
         )
@@ -246,7 +254,7 @@ def main():
             type_path="val",
             data_dir=data_args.data_dir,
             n_obs=data_args.n_val,
-            max_target_length=data_args.eval_max_length,
+            max_target_length=data_args.val_max_target_length,
             max_source_length=data_args.max_source_length,
             prefix=model.config.prefix or "",
         )
@@ -259,7 +267,7 @@ def main():
             type_path="test",
             data_dir=data_args.data_dir,
             n_obs=data_args.n_test,
-            max_target_length=data_args.eval_max_length,
+            max_target_length=data_args.test_max_target_length,
             max_source_length=data_args.max_source_length,
             prefix=model.config.prefix or "",
         )
@@ -310,7 +318,7 @@ def main():
         logger.info("*** Evaluate ***")
 
         metrics = trainer.evaluate(
-            metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
+            metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
         )
         metrics["val_n_objs"] = data_args.n_val
         metrics["val_loss"] = round(metrics["val_loss"], 4)
@@ -326,7 +334,7 @@ def main():
         test_output = trainer.predict(
             test_dataset=test_dataset,
             metric_key_prefix="test",
-            max_length=data_args.eval_max_length,
+            max_length=data_args.val_max_target_length,
             num_beams=data_args.eval_beams,
         )
         metrics = test_output.metrics
diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py
index 9ce347ed894..e8f7d431b6c 100644
--- a/examples/seq2seq/test_finetune_trainer.py
+++ b/examples/seq2seq/test_finetune_trainer.py
@@ -137,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus):
             --n_train 8
             --n_val 8
             --max_source_length 13,744
-            --max_length 13,744
-            --eval_max_length 13,744
+            --max_target_length 13,744
+            --val_max_target_length 13,744
             --do_train
             --do_eval
             --do_predict
diff --git a/examples/seq2seq/train_distil_marian_enro.sh b/examples/seq2seq/train_distil_marian_enro.sh
index 78bc6776cb9..82f7725003e 100644
--- a/examples/seq2seq/train_distil_marian_enro.sh
+++ b/examples/seq2seq/train_distil_marian_enro.sh
@@ -29,7 +29,8 @@ python finetune_trainer.py \
     --freeze_encoder --freeze_embeds \
     --num_train_epochs=6 \
     --save_steps 3000 --eval_steps 3000 \
-    --max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
+    --max_source_length $MAX_LEN --max_target_length $MAX_LEN \
+    --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
     --do_train --do_eval --do_predict \
     --evaluation_strategy steps \
     --predict_with_generate --logging_first_step \
diff --git a/examples/seq2seq/train_distil_marian_enro_tpu.sh b/examples/seq2seq/train_distil_marian_enro_tpu.sh
index 7239d83a770..82f6ce1406e 100644
--- a/examples/seq2seq/train_distil_marian_enro_tpu.sh
+++ b/examples/seq2seq/train_distil_marian_enro_tpu.sh
@@ -30,7 +30,8 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
     --num_train_epochs=6 \
     --save_steps 500 --eval_steps 500 \
     --logging_first_step --logging_steps 200 \
-    --max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
+    --max_source_length $MAX_LEN --max_target_length $MAX_LEN \
+    --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
     --do_train --do_eval \
     --evaluation_strategy steps \
     --prediction_loss_only \
diff --git a/examples/seq2seq/train_distilbart_cnn.sh b/examples/seq2seq/train_distilbart_cnn.sh
index 70b4ff9bf09..ec0aec8e597 100644
--- a/examples/seq2seq/train_distilbart_cnn.sh
+++ b/examples/seq2seq/train_distilbart_cnn.sh
@@ -32,7 +32,7 @@ python finetune_trainer.py \
     --num_train_epochs=2 \
     --save_steps 3000 --eval_steps 3000 \
     --logging_first_step \
-    --max_length 56 --eval_max_length $MAX_TGT_LEN \
+    --max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN\
     --do_train --do_eval --do_predict \
     --evaluation_strategy steps \
     --predict_with_generate --sortish_sampler \
diff --git a/examples/seq2seq/train_mbart_cc25_enro.sh b/examples/seq2seq/train_mbart_cc25_enro.sh
index cfce15c1ec0..2b603eda7c3 100644
--- a/examples/seq2seq/train_mbart_cc25_enro.sh
+++ b/examples/seq2seq/train_mbart_cc25_enro.sh
@@ -24,7 +24,7 @@ python finetune_trainer.py \
     --src_lang en_XX --tgt_lang ro_RO \
     --freeze_embeds \
     --per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
-    --max_source_length 128 --max_length 128 --eval_max_length 128 \
+    --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --test_max_target_length 128\
     --sortish_sampler \
     --num_train_epochs 6 \
     --save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 0b74bfd57fc..437cdf2e632 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -330,7 +330,7 @@ class Seq2SeqDataCollator:
             [x["src_texts"] for x in batch],
             tgt_texts=[x["tgt_texts"] for x in batch],
             max_length=self.data_args.max_source_length,
-            max_target_length=self.data_args.max_length,
+            max_target_length=self.data_args.max_target_length,
             padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
             return_tensors="pt",
             **self.dataset_kwargs,
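Note on usage: with this rename, invocations that still pass --max_length or --eval_max_length to finetune_trainer.py are no longer recognized; the target length is now set separately for training (--max_target_length), validation (--val_max_target_length), and test (--test_max_target_length). A minimal sketch of an updated call follows; the model name, data/output directories, and the concrete length values are illustrative placeholders, and --model_name_or_path, --data_dir, and --output_dir are assumed to be the script's usual arguments outside the hunks shown here:

    # placeholders: substitute your own model, data dir, output dir, and lengths
    python finetune_trainer.py \
        --model_name_or_path <model> \
        --data_dir <data_dir> \
        --output_dir <output_dir> \
        --max_source_length 128 --max_target_length 128 \
        --val_max_target_length 142 --test_max_target_length 142 \
        --do_train --do_eval --do_predict \
        --evaluation_strategy steps \
        --predict_with_generate

As the finetune_trainer.py hunks above show, --val_max_target_length both truncates validation targets and overrides the max_length passed to model.generate during evaluate and predict, while --test_max_target_length only controls truncation of the test-set targets.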