Unifying training argument type annotations (#17934)

* doc: Unify training arg type annotations

* wip: extracting enum type from Union

* blackening
This commit is contained in:
Jannis Born 2022-06-30 14:53:32 +02:00 committed by GitHub
parent 205bc4152c
commit 4f8361afe7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 8 deletions

View File

@ -92,7 +92,11 @@ class HfArgumentParser(ArgumentParser):
" the argument parser only supports one type per argument."
f" Problem encountered in field '{field.name}'."
)
if bool not in field.type.__args__:
if type(None) not in field.type.__args__:
# filter `str` in Union
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
origin_type = getattr(field.type, "__origin__", field.type)
elif bool not in field.type.__args__:
# filter `NoneType` in Union (except for `Union[bool, NoneType]`)
field.type = (
field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]

View File

@ -20,7 +20,7 @@ import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from .debug_utils import DebugOption
from .trainer_utils import (
@ -493,7 +493,7 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluation_strategy: IntervalStrategy = field(
evaluation_strategy: Union[IntervalStrategy, str] = field(
default="no",
metadata={"help": "The evaluation strategy to use."},
)
@ -559,7 +559,7 @@ class TrainingArguments:
default=-1,
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
)
lr_scheduler_type: SchedulerType = field(
lr_scheduler_type: Union[SchedulerType, str] = field(
default="linear",
metadata={"help": "The scheduler type to use."},
)
@ -596,14 +596,14 @@ class TrainingArguments:
},
)
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
logging_strategy: IntervalStrategy = field(
logging_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The logging strategy to use."},
)
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
save_strategy: IntervalStrategy = field(
save_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
@ -815,7 +815,7 @@ class TrainingArguments:
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
optim: OptimizerNames = field(
optim: Union[OptimizerNames, str] = field(
default="adamw_hf",
metadata={"help": "The optimizer to use."},
)
@ -868,7 +868,7 @@ class TrainingArguments:
hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_strategy: HubStrategy = field(
hub_strategy: Union[HubStrategy, str] = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)