Add compute_loss_func to Seq2SeqTrainer (#35136)
commit 5cabc75b4b
parent 90f256c90c
@@ -64,6 +64,7 @@ class Seq2SeqTrainer(Trainer):
             Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]
         ] = None,
         model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
+        compute_loss_func: Optional[Callable] = None,
         compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
         callbacks: Optional[List["TrainerCallback"]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@@ -77,6 +78,7 @@ class Seq2SeqTrainer(Trainer):
             eval_dataset=eval_dataset,
             processing_class=processing_class,
             model_init=model_init,
+            compute_loss_func=compute_loss_func,
             compute_metrics=compute_metrics,
             callbacks=callbacks,
             optimizers=optimizers,
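The change is a pure pass-through: the new compute_loss_func argument accepted by Seq2SeqTrainer.__init__ is forwarded unchanged to the base Trainer. Below is a minimal usage sketch, not part of this commit. It assumes the callable follows the base Trainer's compute_loss_func contract (it receives the model outputs, the labels, and num_items_in_batch, and returns a scalar loss); weighted_cross_entropy and the model, train_dataset, eval_dataset, and tokenizer names are hypothetical placeholders.

    import torch.nn.functional as F

    from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments


    def weighted_cross_entropy(outputs, labels, num_items_in_batch=None):
        # Hypothetical custom loss: token-level cross-entropy summed over the
        # batch, normalized by num_items_in_batch when the Trainer provides it.
        logits = outputs.logits  # (batch, seq_len, vocab_size)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        if num_items_in_batch is not None:
            loss = loss / num_items_in_batch
        return loss


    trainer = Seq2SeqTrainer(
        model=model,                        # placeholder: any pretrained seq2seq model
        args=Seq2SeqTrainingArguments(output_dir="out"),
        train_dataset=train_dataset,        # placeholder dataset
        eval_dataset=eval_dataset,          # placeholder dataset
        processing_class=tokenizer,         # placeholder tokenizer/processor
        compute_loss_func=weighted_cross_entropy,  # forwarded to the base Trainer by this commit
    )

Because the argument is only forwarded, how and when the callable is invoked (including any loss scaling tied to num_items_in_batch) is inherited from the base Trainer rather than defined here.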