Improve typing in TrainingArgument (#36944)

* Better error message in TrainingArgument typing checks

* Better typing

* Small fixes

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
Yuanyuan Chen 2025-05-21 21:54:38 +08:00 committed by GitHub
parent 174684a9b6
commit ae3e4e2d97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 8 deletions

View File

@ -907,7 +907,7 @@ class TrainingArguments:
default="linear",
metadata={"help": "The scheduler type to use."},
)
lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
lr_scheduler_kwargs: Optional[Union[dict[str, Any], str]] = field(
default_factory=dict,
metadata={
"help": (
@ -1230,11 +1230,11 @@ class TrainingArguments:
)
},
)
fsdp_config: Optional[Union[dict, str]] = field(
fsdp_config: Optional[Union[dict[str, Any], str]] = field(
default=None,
metadata={
"help": (
"Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
"Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
"fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`."
)
},
@ -1366,7 +1366,7 @@ class TrainingArguments:
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
gradient_checkpointing_kwargs: Optional[Union[dict[str, Any], str]] = field(
default=None,
metadata={
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
@ -1451,7 +1451,7 @@ class TrainingArguments:
)
},
)
ddp_timeout: Optional[int] = field(
ddp_timeout: int = field(
default=1800,
metadata={
"help": "Overrides the default timeout for distributed training (value should be given in seconds)."
@ -1667,7 +1667,7 @@ class TrainingArguments:
) and self.metric_for_best_model is None:
self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None:
self.greater_is_better = not (self.metric_for_best_model.endswith("loss"))
self.greater_is_better = not self.metric_for_best_model.endswith("loss")
if self.run_name is None:
self.run_name = self.output_dir
if self.framework == "pt" and is_torch_available():
@ -2140,7 +2140,7 @@ class TrainingArguments:
f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
# We delay the init of `PartialState` to the end for clarity
accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False}
if isinstance(self.accelerator_config, AcceleratorConfig):
accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
"use_configured_state", False

View File

@ -446,7 +446,7 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertEqual(
len(raw_dict_fields),
0,
"Found invalid raw `dict` types in the `TrainingArgument` typings. "
f"Found invalid raw `dict` types in the `TrainingArgument` typings, which are {raw_dict_fields}. "
"This leads to issues with the CLI. Please turn this into `Optional[Union[dict[str, Any], str]]`",
)