train args defaulting None marked as Optional (#17156)

Co-authored-by: Dom Miketa <dmiketa@exscientia.co.uk>
This commit is contained in:
Dom Miketa 2022-05-10 15:09:34 +01:00 committed by GitHub
parent 6d80c92c77
commit 1766fa2159
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 13 deletions

View File

@ -582,7 +582,7 @@ class TrainingArguments:
)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
bf16: bool = field(
default=False,
metadata={
@ -616,14 +616,14 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
)
tf32: bool = field(
tf32: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
},
)
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
xpu_backend: str = field(
xpu_backend: Optional[str] = field(
default=None,
metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
)
@ -648,7 +648,7 @@ class TrainingArguments:
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
dataloader_num_workers: int = field(
default=0,
metadata={
@ -770,14 +770,14 @@ class TrainingArguments:
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
)
hub_model_id: str = field(
hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_strategy: HubStrategy = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
gradient_checkpointing: bool = field(
default=False,
@ -793,13 +793,15 @@ class TrainingArguments:
default="auto",
metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
)
push_to_hub_model_id: str = field(
push_to_hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
)
push_to_hub_organization: str = field(
push_to_hub_organization: Optional[str] = field(
default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
)
push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
push_to_hub_token: Optional[str] = field(
default=None, metadata={"help": "The token to use to push to the Model Hub."}
)
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",

View File

@ -14,7 +14,7 @@
import warnings
from dataclasses import dataclass, field
from typing import Tuple
from typing import Optional, Tuple
from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required
@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
Whether to activate the XLA compilation or not.
"""
tpu_name: str = field(
tpu_name: Optional[str] = field(
default=None,
metadata={"help": "Name of TPU"},
)
tpu_zone: str = field(
tpu_zone: Optional[str] = field(
default=None,
metadata={"help": "Zone of TPU"},
)
gcp_project: str = field(
gcp_project: Optional[str] = field(
default=None,
metadata={"help": "Name of Cloud TPU-enabled project"},
)