Allow scheduler parameters (#26480)
* Allow for scheduler kwargs
* Formatting
* Arguments checks, passing the tests
* Black failed somehow

---------

Co-authored-by: Pierre <pierre@avatarin.com>
This commit is contained in:
parent ac5d4cf6de
commit 7e1eff7600
@@ -337,6 +337,7 @@ def get_scheduler(
     optimizer: Optimizer,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
+    scheduler_specific_kwargs: Optional[dict] = None,
 ):
     """
     Unified API to get any scheduler from its name.
@@ -352,6 +353,9 @@ def get_scheduler(
         num_training_steps (`int``, *optional*):
             The number of training steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
+        scheduler_specific_kwargs (`dict`, *optional*):
+            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
+            parameters will cause the scheduler function to raise a TypeError.
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
@@ -372,7 +376,15 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    return schedule_func(
+        optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        **scheduler_specific_kwargs,
+    )
 
 
 class AdamW(Optimizer):
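The hunks above let callers forward schedule-specific arguments through get_scheduler instead of calling the underlying schedule function by hand. A minimal sketch of the new call path, assuming torch is installed and using a toy model and placeholder step counts purely for illustration:

import torch

from transformers import get_scheduler

# Toy model and optimizer, illustration only.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# scheduler_specific_kwargs is forwarded as **kwargs to the underlying schedule
# function; num_cycles is an argument of the cosine-with-hard-restarts schedule.
lr_scheduler = get_scheduler(
    "cosine_with_restarts",
    optimizer=optimizer,
    num_warmup_steps=10,
    num_training_steps=100,
    scheduler_specific_kwargs={"num_cycles": 2},
)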
@@ -1137,6 +1137,7 @@ class Trainer:
                 optimizer=self.optimizer if optimizer is None else optimizer,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
+                **self.args.lr_scheduler_kwargs,
             )
             self._created_lr_scheduler = True
         return self.lr_scheduler
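Because the Trainer forwards args.lr_scheduler_kwargs verbatim, an argument the selected schedule does not accept surfaces as a TypeError when the scheduler is built, as the updated docstring notes. A small sketch of that documented failure mode, again with a toy optimizer:

import torch

from transformers import get_scheduler

optimizer = torch.optim.AdamW(torch.nn.Linear(4, 2).parameters(), lr=5e-5)

try:
    # num_cycles is not an argument of the linear schedule, so the underlying
    # schedule function is expected to reject it.
    get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=10,
        num_training_steps=100,
        scheduler_specific_kwargs={"num_cycles": 2},
    )
except TypeError as err:
    print(f"Rejected as documented: {err}")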
@@ -238,6 +238,8 @@ class TrainingArguments:
             when all data is exhausted
         lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
             The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+        lr_scheduler_kwargs ('dict', *optional*, defaults to {}):
+            The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
         warmup_ratio (`float`, *optional*, defaults to 0.0):
             Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
         warmup_steps (`int`, *optional*, defaults to 0):
@@ -729,6 +731,14 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
+    lr_scheduler_kwargs: Optional[Dict] = field(
+        default_factory=dict,
+        metadata={
+            "help": (
+                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
+            )
+        },
+    )
     warmup_ratio: float = field(
         default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
     )
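Putting the pieces together on the arguments side, a hedged end-to-end sketch of the new lr_scheduler_kwargs field; the output_dir path and the hyperparameter values are placeholders, and the model, dataset, and remaining Trainer setup are omitted:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # placeholder path
    lr_scheduler_type="cosine_with_restarts",
    # Forwarded by Trainer.create_scheduler as **self.args.lr_scheduler_kwargs,
    # which get_scheduler then passes to the cosine-with-hard-restarts schedule.
    lr_scheduler_kwargs={"num_cycles": 2},
    warmup_ratio=0.1,
    num_train_epochs=3,
)

Before this commit, reaching schedule-specific arguments such as num_cycles typically meant overriding Trainer.create_scheduler or building the scheduler yourself and handing it to the Trainer; now they travel through TrainingArguments directly.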