mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 10:41:07 +06:00
Updated Trainer args typing (#20655)
This commit is contained in:
parent
3994c04585
commit
3ac040bca1
@ -298,13 +298,13 @@ class Trainer:
|
|||||||
args: TrainingArguments = None,
|
args: TrainingArguments = None,
|
||||||
data_collator: Optional[DataCollator] = None,
|
data_collator: Optional[DataCollator] = None,
|
||||||
train_dataset: Optional[Dataset] = None,
|
train_dataset: Optional[Dataset] = None,
|
||||||
eval_dataset: Optional[Dataset] = None,
|
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||||
model_init: Callable[[], PreTrainedModel] = None,
|
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||||
callbacks: Optional[List[TrainerCallback]] = None,
|
callbacks: Optional[List[TrainerCallback]] = None,
|
||||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||||
):
|
):
|
||||||
if args is None:
|
if args is None:
|
||||||
output_dir = "tmp_trainer"
|
output_dir = "tmp_trainer"
|
||||||
|
Loading…
Reference in New Issue
Block a user