diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 144f9437335..5c01cbd0427 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -298,13 +298,13 @@ class Trainer: args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, - model_init: Callable[[], PreTrainedModel] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ): if args is None: output_dir = "tmp_trainer"