Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Improve typing in TrainingArgument (#36944)
* Better error message in TrainingArgument typing checks
* Better typing
* Small fixes

Signed-off-by: cyy <cyyever@outlook.com>
commit ae3e4e2d97
parent 174684a9b6
@@ -907,7 +907,7 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
-    lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
+    lr_scheduler_kwargs: Optional[Union[dict[str, Any], str]] = field(
         default_factory=dict,
         metadata={
             "help": (
@@ -1230,11 +1230,11 @@ class TrainingArguments:
             )
         },
     )
-    fsdp_config: Optional[Union[dict, str]] = field(
+    fsdp_config: Optional[Union[dict[str, Any], str]] = field(
         default=None,
         metadata={
             "help": (
-                "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
+                "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
                 "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`."
             )
         },
@@ -1366,7 +1366,7 @@ class TrainingArguments:
             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
         },
     )
-    gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
+    gradient_checkpointing_kwargs: Optional[Union[dict[str, Any], str]] = field(
         default=None,
         metadata={
             "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
@@ -1451,7 +1451,7 @@ class TrainingArguments:
             )
         },
     )
-    ddp_timeout: Optional[int] = field(
+    ddp_timeout: int = field(
         default=1800,
         metadata={
             "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
@@ -1667,7 +1667,7 @@ class TrainingArguments:
         ) and self.metric_for_best_model is None:
             self.metric_for_best_model = "loss"
         if self.greater_is_better is None and self.metric_for_best_model is not None:
-            self.greater_is_better = not (self.metric_for_best_model.endswith("loss"))
+            self.greater_is_better = not self.metric_for_best_model.endswith("loss")
         if self.run_name is None:
             self.run_name = self.output_dir
         if self.framework == "pt" and is_torch_available():
@@ -2140,7 +2140,7 @@ class TrainingArguments:
                 f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
             )
         # We delay the init of `PartialState` to the end for clarity
-        accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
+        accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False}
         if isinstance(self.accelerator_config, AcceleratorConfig):
             accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
                 "use_configured_state", False
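The annotation changes above are static-typing improvements rather than behavioral ones: `lr_scheduler_kwargs`, `fsdp_config`, and `gradient_checkpointing_kwargs` still accept either a mapping or a string, but their key/value types are now spelled out, and `ddp_timeout` becomes a plain `int` since it always carries the default of 1800. A minimal sketch of the intended usage, assuming `torch` and `accelerate` are installed; the kwargs shown (`num_cycles`, `use_reentrant`) are only illustrative and are not part of this commit:

from transformers import TrainingArguments

# Sketch only: with the new annotations a type checker sees `lr_scheduler_kwargs`
# as `Optional[Union[dict[str, Any], str]]` instead of a bare `dict`.
args = TrainingArguments(
    output_dir="tmp_out",
    lr_scheduler_type="cosine_with_restarts",
    lr_scheduler_kwargs={"num_cycles": 2},  # a dict[str, Any] (or a string, per the annotation)
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    ddp_timeout=1800,  # now annotated as plain `int`
)
print(type(args.lr_scheduler_kwargs), args.ddp_timeout)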
@@ -446,7 +446,7 @@ class HfArgumentParserTest(unittest.TestCase):
         self.assertEqual(
             len(raw_dict_fields),
             0,
-            "Found invalid raw `dict` types in the `TrainingArgument` typings. "
+            f"Found invalid raw `dict` types in the `TrainingArgument` typings, which are {raw_dict_fields}. "
             "This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
         )
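The assertion above relies on a `raw_dict_fields` collection built earlier in the test, which this hunk does not show. A minimal sketch, not the repository's test code, of how bare `dict` annotations on the `TrainingArguments` dataclass could be collected:

import re
from dataclasses import fields

from transformers import TrainingArguments

# Sketch only: flag fields whose annotation contains a bare `dict` that is not
# parameterized as `dict[...]`. After this commit every dict-typed field uses
# `dict[str, Any]`, so the list is expected to be empty.
raw_dict_fields = [
    f.name for f in fields(TrainingArguments) if re.search(r"\bdict\b(?!\[)", str(f.type))
]
assert not raw_dict_fields, f"Raw `dict` annotations found: {raw_dict_fields}"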