Allow scheduler parameters (#26480)

* Allow for scheduler kwargs

* Formatting

* Arguments checks, passing the tests

* Black failed somehow

---------

Co-authored-by: Pierre <pierre@avatarin.com>
Plemeur 2023-11-08 06:40:00 +09:00 committed by GitHub
parent ac5d4cf6de
commit 7e1eff7600
3 changed files with 24 additions and 1 deletion


@@ -337,6 +337,7 @@ def get_scheduler(
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
scheduler_specific_kwargs: Optional[dict] = None,
):
"""
Unified API to get any scheduler from its name.
@@ -352,6 +353,9 @@ def get_scheduler(
num_training_steps (`int`, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional); the function will raise an error if it is unset and the scheduler type requires it.
scheduler_specific_kwargs (`dict`, *optional*):
Extra parameters for schedulers such as cosine with restarts. Parameters that do not match the selected
scheduler type will cause the scheduler function to raise a `TypeError`.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
@@ -372,7 +376,15 @@ def get_scheduler(
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
if scheduler_specific_kwargs is None:
scheduler_specific_kwargs = {}
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**scheduler_specific_kwargs,
)
class AdamW(Optimizer):

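For context, here is a minimal sketch (not part of the commit) of how the new `scheduler_specific_kwargs` argument can be used when calling `get_scheduler` directly; the toy model, optimizer, and step counts are placeholders:

```python
# Sketch of the new `scheduler_specific_kwargs` argument; the model, optimizer,
# and step counts below are illustrative placeholders.
import torch
from transformers import get_scheduler

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# "num_cycles" is forwarded to get_cosine_with_hard_restarts_schedule_with_warmup.
lr_scheduler = get_scheduler(
    "cosine_with_restarts",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
    scheduler_specific_kwargs={"num_cycles": 3},
)

# Passing a parameter the selected scheduler does not accept (e.g. "num_cycles"
# with the "linear" scheduler) raises a TypeError, as noted in the docstring above.
```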

@@ -1137,6 +1137,7 @@ class Trainer:
optimizer=self.optimizer if optimizer is None else optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
)
self._created_lr_scheduler = True
return self.lr_scheduler


@@ -238,6 +238,8 @@ class TrainingArguments:
when all data is exhausted
lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
lr_scheduler_kwargs (`dict`, *optional*, defaults to {}):
The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
warmup_ratio (`float`, *optional*, defaults to 0.0):
Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
warmup_steps (`int`, *optional*, defaults to 0):
@@ -729,6 +731,14 @@ class TrainingArguments:
default="linear",
metadata={"help": "The scheduler type to use."},
)
lr_scheduler_kwargs: Optional[Dict] = field(
default_factory=dict,
metadata={
"help": (
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
)
},
)
warmup_ratio: float = field(
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
)
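Taken together, the three hunks let users configure extra scheduler parameters through `TrainingArguments`. A hedged end-to-end sketch (the output directory and surrounding training setup are illustrative, not from this commit):

```python
# Sketch of the usage enabled by this commit; only lr_scheduler_type and
# lr_scheduler_kwargs are the point here, everything else is a placeholder.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    lr_scheduler_type="cosine_with_restarts",
    # Forwarded by Trainer.create_scheduler to get_scheduler as
    # scheduler_specific_kwargs, per the trainer.py hunk above.
    lr_scheduler_kwargs={"num_cycles": 2},
    warmup_ratio=0.1,
    num_train_epochs=3,
)

# A Trainer built with these args (Trainer(model=..., args=args, train_dataset=...))
# will apply the extra parameters when it creates the learning rate scheduler.
```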