mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Unifying training argument type annotations (#17934)
* doc: Unify training arg type annotations * wip: extracting enum type from Union * blackening
This commit is contained in:
parent
205bc4152c
commit
4f8361afe7
@ -92,7 +92,11 @@ class HfArgumentParser(ArgumentParser):
|
||||
" the argument parser only supports one type per argument."
|
||||
f" Problem encountered in field '{field.name}'."
|
||||
)
|
||||
if bool not in field.type.__args__:
|
||||
if type(None) not in field.type.__args__:
|
||||
# filter `str` in Union
|
||||
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
|
||||
origin_type = getattr(field.type, "__origin__", field.type)
|
||||
elif bool not in field.type.__args__:
|
||||
# filter `NoneType` in Union (except for `Union[bool, NoneType]`)
|
||||
field.type = (
|
||||
field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
|
||||
|
@ -20,7 +20,7 @@ import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .debug_utils import DebugOption
|
||||
from .trainer_utils import (
|
||||
@ -493,7 +493,7 @@ class TrainingArguments:
|
||||
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||
evaluation_strategy: IntervalStrategy = field(
|
||||
evaluation_strategy: Union[IntervalStrategy, str] = field(
|
||||
default="no",
|
||||
metadata={"help": "The evaluation strategy to use."},
|
||||
)
|
||||
@ -559,7 +559,7 @@ class TrainingArguments:
|
||||
default=-1,
|
||||
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
|
||||
)
|
||||
lr_scheduler_type: SchedulerType = field(
|
||||
lr_scheduler_type: Union[SchedulerType, str] = field(
|
||||
default="linear",
|
||||
metadata={"help": "The scheduler type to use."},
|
||||
)
|
||||
@ -596,14 +596,14 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
|
||||
logging_strategy: IntervalStrategy = field(
|
||||
logging_strategy: Union[IntervalStrategy, str] = field(
|
||||
default="steps",
|
||||
metadata={"help": "The logging strategy to use."},
|
||||
)
|
||||
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
|
||||
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
|
||||
save_strategy: IntervalStrategy = field(
|
||||
save_strategy: Union[IntervalStrategy, str] = field(
|
||||
default="steps",
|
||||
metadata={"help": "The checkpoint save strategy to use."},
|
||||
)
|
||||
@ -815,7 +815,7 @@ class TrainingArguments:
|
||||
label_smoothing_factor: float = field(
|
||||
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
||||
)
|
||||
optim: OptimizerNames = field(
|
||||
optim: Union[OptimizerNames, str] = field(
|
||||
default="adamw_hf",
|
||||
metadata={"help": "The optimizer to use."},
|
||||
)
|
||||
@ -868,7 +868,7 @@ class TrainingArguments:
|
||||
hub_model_id: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||
)
|
||||
hub_strategy: HubStrategy = field(
|
||||
hub_strategy: Union[HubStrategy, str] = field(
|
||||
default="every_save",
|
||||
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user