mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Mark training arguments that default to None as Optional (#17156)
Co-authored-by: Dom Miketa <dmiketa@exscientia.co.uk>
This commit is contained in:
parent
6d80c92c77
commit
1766fa2159
@ -582,7 +582,7 @@ class TrainingArguments:
|
||||
)
|
||||
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
|
||||
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||
data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
|
||||
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
|
||||
bf16: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@ -616,14 +616,14 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
|
||||
)
|
||||
tf32: bool = field(
|
||||
tf32: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
|
||||
},
|
||||
)
|
||||
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||
xpu_backend: str = field(
|
||||
xpu_backend: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
|
||||
)
|
||||
@ -648,7 +648,7 @@ class TrainingArguments:
|
||||
dataloader_drop_last: bool = field(
|
||||
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
||||
)
|
||||
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||
eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||
dataloader_num_workers: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
@ -770,14 +770,14 @@ class TrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
|
||||
)
|
||||
hub_model_id: str = field(
|
||||
hub_model_id: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||
)
|
||||
hub_strategy: HubStrategy = field(
|
||||
default="every_save",
|
||||
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
|
||||
)
|
||||
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||
hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
|
||||
gradient_checkpointing: bool = field(
|
||||
default=False,
|
||||
@ -793,13 +793,15 @@ class TrainingArguments:
|
||||
default="auto",
|
||||
metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
|
||||
)
|
||||
push_to_hub_model_id: str = field(
|
||||
push_to_hub_model_id: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
|
||||
)
|
||||
push_to_hub_organization: str = field(
|
||||
push_to_hub_organization: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
|
||||
)
|
||||
push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||
push_to_hub_token: Optional[str] = field(
|
||||
default=None, metadata={"help": "The token to use to push to the Model Hub."}
|
||||
)
|
||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||
mp_parameters: str = field(
|
||||
default="",
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import cached_property, is_tf_available, logging, tf_required
|
||||
@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
|
||||
Whether to activate the XLA compilation or not.
|
||||
"""
|
||||
|
||||
tpu_name: str = field(
|
||||
tpu_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Name of TPU"},
|
||||
)
|
||||
|
||||
tpu_zone: str = field(
|
||||
tpu_zone: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Zone of TPU"},
|
||||
)
|
||||
|
||||
gcp_project: str = field(
|
||||
gcp_project: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Name of Cloud TPU-enabled project"},
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user