From ae3e4e2d97d7342a626ce5bd36e19fdcc07d4d35 Mon Sep 17 00:00:00 2001
From: Yuanyuan Chen
Date: Wed, 21 May 2025 21:54:38 +0800
Subject: [PATCH] Improve typing in TrainingArgument (#36944)

* Better error message in TrainingArgument typing checks

* Better typing

* Small fixes

Signed-off-by: cyy

---------

Signed-off-by: cyy
---
 src/transformers/training_args.py | 14 +++++++-------
 tests/utils/test_hf_argparser.py  |  2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 3b1b8c58b5d..04a4972e0a7 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -907,7 +907,7 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
-    lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
+    lr_scheduler_kwargs: Optional[Union[dict[str, Any], str]] = field(
         default_factory=dict,
         metadata={
             "help": (
@@ -1230,11 +1230,11 @@ class TrainingArguments:
             )
         },
     )
-    fsdp_config: Optional[Union[dict, str]] = field(
+    fsdp_config: Optional[Union[dict[str, Any], str]] = field(
         default=None,
         metadata={
             "help": (
-                "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a  "
+                "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
                 "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`."
             )
         },
@@ -1366,7 +1366,7 @@ class TrainingArguments:
             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
         },
     )
-    gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
+    gradient_checkpointing_kwargs: Optional[Union[dict[str, Any], str]] = field(
         default=None,
         metadata={
             "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
@@ -1451,7 +1451,7 @@ class TrainingArguments:
             )
         },
     )
-    ddp_timeout: Optional[int] = field(
+    ddp_timeout: int = field(
         default=1800,
         metadata={
             "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
@@ -1667,7 +1667,7 @@ class TrainingArguments:
         ) and self.metric_for_best_model is None:
             self.metric_for_best_model = "loss"
         if self.greater_is_better is None and self.metric_for_best_model is not None:
-            self.greater_is_better = not (self.metric_for_best_model.endswith("loss"))
+            self.greater_is_better = not self.metric_for_best_model.endswith("loss")
         if self.run_name is None:
             self.run_name = self.output_dir
         if self.framework == "pt" and is_torch_available():
@@ -2140,7 +2140,7 @@ class TrainingArguments:
                 f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
             )
         # We delay the init of `PartialState` to the end for clarity
-        accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
+        accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False}
         if isinstance(self.accelerator_config, AcceleratorConfig):
             accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
                 "use_configured_state", False
diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py
index 6845ea746c4..dbcc5b00660 100644
--- a/tests/utils/test_hf_argparser.py
+++ b/tests/utils/test_hf_argparser.py
@@ -446,7 +446,7 @@ class HfArgumentParserTest(unittest.TestCase):
         self.assertEqual(
             len(raw_dict_fields),
             0,
-            "Found invalid raw `dict` types in the `TrainingArgument` typings. "
+            f"Found invalid raw `dict` types in the `TrainingArgument` typings, which are {raw_dict_fields}. "
             "This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
         )
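
For context, below is a minimal standalone sketch of the kind of annotation check the updated assertion reports on: it walks a dataclass's type hints and collects fields annotated with a dict type that has no `str` alternative, which is what the `raw_dict_fields` assertion in `test_hf_argparser.py` guards against for the CLI. The `ExampleArguments` dataclass and the helper name are illustrative assumptions, not the actual test implementation.

from dataclasses import dataclass, field, fields
from typing import Any, Optional, Union, get_args, get_origin, get_type_hints


@dataclass
class ExampleArguments:
    # A bare `dict` annotation gives an argument parser nothing it can build from a CLI string.
    bad_kwargs: dict = field(default_factory=dict)
    # The pattern the patch standardizes on: typed dict values plus a `str` (JSON) fallback.
    good_kwargs: Optional[Union[dict[str, Any], str]] = None


def dict_fields_without_str_fallback(cls) -> list[str]:
    """Names of dataclass fields annotated with a dict type that lacks a `str` alternative."""
    offenders = []
    hints = get_type_hints(cls)  # resolves string annotations, if any
    for f in fields(cls):
        ann = hints[f.name]
        if get_origin(ann) is Union:
            args = get_args(ann)  # Optional[...] is flattened into this Union
            has_dict = any(a is dict or get_origin(a) is dict for a in args)
            has_str = str in args
            if has_dict and not has_str:
                offenders.append(f.name)
        elif ann is dict or get_origin(ann) is dict:
            offenders.append(f.name)
    return offenders


offenders = dict_fields_without_str_fallback(ExampleArguments)
print(f"Found raw `dict` fields without a `str` fallback: {offenders}")
# Found raw `dict` fields without a `str` fallback: ['bad_kwargs']

The `Optional[Union[dict[str, Any], str]]` shape kept by the patch reflects the same constraint: the value can arrive as an already-built `dict` in Python code or as a string set from the command line (per the fields' help text, e.g. an fsdp json config file or an already loaded json file as `dict`).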