Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Allow scheduler parameters (#26480)

* Allow for scheduler kwargs
* Formatting
* Arguments checks, passing the tests
* Black failed somehow

Co-authored-by: Pierre <pierre@avatarin.com>

parent ac5d4cf6de · commit 7e1eff7600
src/transformers/optimization.py

@@ -337,6 +337,7 @@ def get_scheduler(
     optimizer: Optimizer,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
+    scheduler_specific_kwargs: Optional[dict] = None,
 ):
     """
     Unified API to get any scheduler from its name.
@@ -352,6 +353,9 @@ def get_scheduler(
         num_training_steps (`int`, *optional*):
             The number of training steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
+        scheduler_specific_kwargs (`dict`, *optional*):
+            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
+            parameters will cause the scheduler function to raise a TypeError.
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
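Taken together with the signature change above, callers can now forward scheduler-specific arguments through `get_scheduler` itself. A minimal sketch of the new call, assuming `torch` and a `transformers` version containing this commit are installed; the toy model, step counts, and `num_cycles` value are purely illustrative:

import torch
from transformers import get_scheduler

# Toy model and optimizer, only here so the scheduler has something to attach to.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# "cosine_with_restarts" maps to get_cosine_with_hard_restarts_schedule_with_warmup;
# its extra `num_cycles` argument travels through scheduler_specific_kwargs.
lr_scheduler = get_scheduler(
    "cosine_with_restarts",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
    scheduler_specific_kwargs={"num_cycles": 2},
)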
@@ -372,7 +376,15 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    return schedule_func(
+        optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        **scheduler_specific_kwargs,
+    )


 class AdamW(Optimizer):
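The final return above unpacks `scheduler_specific_kwargs` directly into the schedule function, so, as the docstring warns, a kwarg the selected scheduler does not accept surfaces as a `TypeError`. A small sketch of that failure mode, under the same assumptions as the previous example; the optimizer and step counts are illustrative:

import torch
from transformers import get_scheduler

optimizer = torch.optim.AdamW(torch.nn.Linear(4, 2).parameters(), lr=5e-5)

try:
    # The linear schedule has no `num_cycles` parameter, so the unpacked kwarg
    # reaches get_linear_schedule_with_warmup and raises TypeError.
    get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=10,
        num_training_steps=100,
        scheduler_specific_kwargs={"num_cycles": 2},
    )
except TypeError as err:
    print(f"Mismatched scheduler kwargs: {err}")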
src/transformers/trainer.py

@@ -1137,6 +1137,7 @@ class Trainer:
                 optimizer=self.optimizer if optimizer is None else optimizer,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
+                **self.args.lr_scheduler_kwargs,
             )
             self._created_lr_scheduler = True
         return self.lr_scheduler
src/transformers/training_args.py

@@ -238,6 +238,8 @@ class TrainingArguments:
             when all data is exhausted
         lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
             The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+        lr_scheduler_kwargs (`dict`, *optional*, defaults to {}):
+            The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
         warmup_ratio (`float`, *optional*, defaults to 0.0):
             Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
         warmup_steps (`int`, *optional*, defaults to 0):
@@ -729,6 +731,14 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
+    lr_scheduler_kwargs: Optional[Dict] = field(
+        default_factory=dict,
+        metadata={
+            "help": (
+                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
+            )
+        },
+    )
     warmup_ratio: float = field(
         default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
     )
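On the Trainer side, the new `lr_scheduler_kwargs` training argument carries the same parameters and is unpacked into the `get_scheduler` call shown in the Trainer hunk above. A minimal sketch of declaring it, following the `{'num_cycles': 1}` example from the field's help text; the `output_dir` and the remaining values are placeholders:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="scheduler-demo",               # placeholder output path
    lr_scheduler_type="cosine_with_restarts",  # a scheduler that accepts extra parameters
    lr_scheduler_kwargs={"num_cycles": 1},     # extra scheduler parameters, per the field's help text
    warmup_ratio=0.1,
    num_train_epochs=3,
)
# A Trainer built with these args forwards lr_scheduler_kwargs when it creates
# its learning-rate scheduler, as shown in the Trainer diff hunk above.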