diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py
index 1316ff3ba99..ac3245a29c8 100644
--- a/src/transformers/hf_argparser.py
+++ b/src/transformers/hf_argparser.py
@@ -92,7 +92,11 @@ class HfArgumentParser(ArgumentParser):
                     " the argument parser only supports one type per argument."
                     f" Problem encountered in field '{field.name}'."
                 )
-            if bool not in field.type.__args__:
+            if type(None) not in field.type.__args__:
+                # filter `str` in Union
+                field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
+                origin_type = getattr(field.type, "__origin__", field.type)
+            elif bool not in field.type.__args__:
                 # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
                 field.type = (
                     field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 5c370cf0720..f65125e348c 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -20,7 +20,7 @@ import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from .debug_utils import DebugOption
 from .trainer_utils import (
@@ -493,7 +493,7 @@ class TrainingArguments:
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
     do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
-    evaluation_strategy: IntervalStrategy = field(
+    evaluation_strategy: Union[IntervalStrategy, str] = field(
         default="no",
         metadata={"help": "The evaluation strategy to use."},
     )
@@ -559,7 +559,7 @@ class TrainingArguments:
         default=-1,
         metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
     )
-    lr_scheduler_type: SchedulerType = field(
+    lr_scheduler_type: Union[SchedulerType, str] = field(
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
@@ -596,14 +596,14 @@ class TrainingArguments:
         },
     )
     logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
-    logging_strategy: IntervalStrategy = field(
+    logging_strategy: Union[IntervalStrategy, str] = field(
         default="steps",
         metadata={"help": "The logging strategy to use."},
     )
     logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
     logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
-    save_strategy: IntervalStrategy = field(
+    save_strategy: Union[IntervalStrategy, str] = field(
         default="steps",
         metadata={"help": "The checkpoint save strategy to use."},
     )
@@ -815,7 +815,7 @@ class TrainingArguments:
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
-    optim: OptimizerNames = field(
+    optim: Union[OptimizerNames, str] = field(
         default="adamw_hf",
         metadata={"help": "The optimizer to use."},
     )
@@ -868,7 +868,7 @@ class TrainingArguments:
     hub_model_id: Optional[str] = field(
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
-    hub_strategy: HubStrategy = field(
+    hub_strategy: Union[HubStrategy, str] = field(
         default="every_save",
         metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
     )