Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 13:20:12 +06:00)
Improve typing in TrainingArgument (#36944)
* Better error message in TrainingArgument typing checks
* Better typing
* Small fixes

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
parent 174684a9b6
commit ae3e4e2d97
@@ -907,7 +907,7 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
-    lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
+    lr_scheduler_kwargs: Optional[Union[dict[str, Any], str]] = field(
         default_factory=dict,
         metadata={
             "help": (
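
The same bare `dict` to `dict[str, Any]` change is applied to `fsdp_config` and `gradient_checkpointing_kwargs` in the hunks below. A minimal sketch (not part of the commit) of why the parameterized form matters: `typing.get_args` recovers nothing from a bare `dict`, so introspection-driven tooling such as the CLI argument parser cannot tell what value types the field expects.

from typing import Any, get_args

print(get_args(dict))            # () -- a bare dict carries no type arguments
print(get_args(dict[str, Any]))  # (<class 'str'>, typing.Any)
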
@@ -1230,7 +1230,7 @@ class TrainingArguments:
             )
         },
     )
-    fsdp_config: Optional[Union[dict, str]] = field(
+    fsdp_config: Optional[Union[dict[str, Any], str]] = field(
         default=None,
         metadata={
             "help": (
@@ -1366,7 +1366,7 @@ class TrainingArguments:
             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
         },
     )
-    gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
+    gradient_checkpointing_kwargs: Optional[Union[dict[str, Any], str]] = field(
         default=None,
         metadata={
             "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
@@ -1451,7 +1451,7 @@ class TrainingArguments:
             )
         },
     )
-    ddp_timeout: Optional[int] = field(
+    ddp_timeout: int = field(
         default=1800,
         metadata={
             "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
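
Dropping `Optional` is sound here: the field always carries the integer default 1800, and None was never a usable value because the timeout is converted to a `timedelta` downstream. A quick sketch of that constraint (the conversion itself, not the surrounding transformers code):

from datetime import timedelta

ddp_timeout: int = 1800
print(timedelta(seconds=ddp_timeout))  # 0:30:00
# timedelta(seconds=None) raises TypeError, so Optional[int] promised a value
# the consumer could never accept.
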
@@ -1667,7 +1667,7 @@ class TrainingArguments:
         ) and self.metric_for_best_model is None:
             self.metric_for_best_model = "loss"
         if self.greater_is_better is None and self.metric_for_best_model is not None:
-            self.greater_is_better = not (self.metric_for_best_model.endswith("loss"))
+            self.greater_is_better = not self.metric_for_best_model.endswith("loss")
         if self.run_name is None:
             self.run_name = self.output_dir
         if self.framework == "pt" and is_torch_available():
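
The removed parentheses were redundant: `not` already negates the whole `endswith` call, so behavior is unchanged. A sanity check of the default, where any metric name ending in "loss" implies lower is better:

for metric in ("eval_loss", "loss", "accuracy", "f1"):
    print(metric, not metric.endswith("loss"))
# eval_loss False
# loss False
# accuracy True
# f1 True
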
@@ -2140,7 +2140,7 @@ class TrainingArguments:
                     f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                 )
         # We delay the init of `PartialState` to the end for clarity
-        accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
+        accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False}
         if isinstance(self.accelerator_config, AcceleratorConfig):
             accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
                 "use_configured_state", False
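
Annotating the literal as `dict[str, Any]` keeps a type checker from inferring `dict[str, bool]` from the two boolean initializers and then rejecting later assignments of non-bool values. A sketch of that failure mode (the `timeout` key is illustrative, not necessarily what the surrounding code assigns):

from datetime import timedelta
from typing import Any

inferred = {"enabled": True, "use_configured_state": False}
# mypy infers dict[str, bool] for `inferred`, so this would be flagged:
# inferred["timeout"] = timedelta(seconds=1800)

annotated: dict[str, Any] = {"enabled": True, "use_configured_state": False}
annotated["timeout"] = timedelta(seconds=1800)  # accepted
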
@@ -446,7 +446,7 @@ class HfArgumentParserTest(unittest.TestCase):
         self.assertEqual(
             len(raw_dict_fields),
             0,
-            "Found invalid raw `dict` types in the `TrainingArgument` typings. "
+            f"Found invalid raw `dict` types in the `TrainingArgument` typings, which are {raw_dict_fields}. "
             "This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
         )
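
The reworked assertion message now names the offending fields instead of only reporting that some exist. A hedged sketch (not the actual test body) of how raw `dict` annotations can be collected with standard typing introspection; only top-level union members are inspected:

from typing import get_args, get_type_hints

from transformers import TrainingArguments

hints = get_type_hints(TrainingArguments)
raw_dict_fields = [
    name
    for name, tp in hints.items()
    if tp is dict or dict in get_args(tp)  # bare dict, alone or inside a Union
]
print(raw_dict_fields)  # expected: [] after this commit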