Add compute_loss_func to Seq2SeqTrainer (#35136)

This commit is contained in:
Cheng-Han Chiang 2024-12-29 22:01:35 +08:00 committed by GitHub
parent 90f256c90c
commit 5cabc75b4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -64,6 +64,7 @@ class Seq2SeqTrainer(Trainer):
Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]
] = None,
model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
callbacks: Optional[List["TrainerCallback"]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@ -77,6 +78,7 @@ class Seq2SeqTrainer(Trainer):
eval_dataset=eval_dataset,
processing_class=processing_class,
model_init=model_init,
compute_loss_func=compute_loss_func,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,