From 3ac040bca1efbf5cfe9604a5b2a10a5392917c20 Mon Sep 17 00:00:00 2001 From: Julian Mack <32888280+julianmack@users.noreply.github.com> Date: Wed, 7 Dec 2022 14:57:39 +0000 Subject: [PATCH] Updated Trainer args typing (#20655) --- src/transformers/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 144f9437335..5c01cbd0427 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -298,13 +298,13 @@ class Trainer: args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, - model_init: Callable[[], PreTrainedModel] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ): if args is None: output_dir = "tmp_trainer"