mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Updated Trainer args typing (#20655)
This commit is contained in:
parent
3994c04585
commit
3ac040bca1
@ -298,13 +298,13 @@ class Trainer:
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
):
|
||||
if args is None:
|
||||
output_dir = "tmp_trainer"
|
||||
|
Loading…
Reference in New Issue
Block a user