diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 07d0571e44c..5cd89b3701c 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -64,6 +64,7 @@ class Seq2SeqTrainer(Trainer): Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"] ] = None, model_init: Optional[Callable[[], "PreTrainedModel"]] = None, + compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None, callbacks: Optional[List["TrainerCallback"]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), @@ -77,6 +78,7 @@ class Seq2SeqTrainer(Trainer): eval_dataset=eval_dataset, processing_class=processing_class, model_init=model_init, + compute_loss_func=compute_loss_func, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers,