Support constant lr with cooldown (#35453)

* Add support for constant learning rate with cooldown

* Add more warmup and decay methods to `get_wsd_schedule`

* Support `num_training_steps` and `num_stable_steps` for `get_wsd_schedule`

* Get the WSD scheduler before the `num_training_steps` decision

* fix code_quality

* Update stable branch logic

* fix code_quality

* Move the stable-stage decision into `get_wsd_schedule`

* Update docstring of `get_wsd_schedule`

* Update `num_training_steps` to be optional

* Update docstring of `get_wsd_schedule`

* Update src/transformers/optimization.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

Commit 48a309d0d2 (parent 9a6be63fdb), authored by Jingze Shi on 2025-02-10 20:21:55 +08:00 and committed by GitHub.
2 changed files with 86 additions and 17 deletions
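
For orientation, here is a minimal usage sketch of the extended schedule. The toy model, learning rate and step counts are illustrative assumptions, not values taken from this commit.

import torch
from transformers.optimization import get_wsd_schedule

model = torch.nn.Linear(10, 2)  # toy model (assumption)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Warm up for 100 steps and decay for 200; with num_training_steps given, the
# stable phase is inferred as 1000 - 100 - 200 = 700 steps.
scheduler = get_wsd_schedule(
    optimizer,
    num_warmup_steps=100,
    num_decay_steps=200,
    num_training_steps=1000,
    warmup_type="linear",
    decay_type="1-sqrt",
    min_lr_ratio=0.1,
)
# Call scheduler.step() after each optimizer.step() in the training loop.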

src/transformers/optimization.py

@@ -393,45 +393,71 @@ def _get_wsd_scheduler_lambda(
     num_warmup_steps: int,
     num_stable_steps: int,
     num_decay_steps: int,
-    num_cycles: float,
+    warmup_type: str,
+    decay_type: str,
     min_lr_ratio: float,
+    num_cycles: float,
 ):
     if current_step < num_warmup_steps:
-        return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step) / float(max(1, num_warmup_steps))
+        if warmup_type == "linear":
+            factor = progress
+        elif warmup_type == "cosine":
+            factor = 0.5 * (1.0 - math.cos(math.pi * progress))
+        elif warmup_type == "1-sqrt":
+            factor = 1.0 - math.sqrt(1.0 - progress)
+        factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
+        return max(0.0, factor)
     if current_step < num_warmup_steps + num_stable_steps:
         return 1.0
     if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
         progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
-        value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
-        return (1.0 - min_lr_ratio) * value + min_lr_ratio
+        if decay_type == "linear":
+            factor = 1.0 - progress
+        elif decay_type == "cosine":
+            factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+        elif decay_type == "1-sqrt":
+            factor = 1.0 - math.sqrt(progress)
+        factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
+        return max(0.0, factor)
     return min_lr_ratio


 def get_wsd_schedule(
     optimizer: Optimizer,
     num_warmup_steps: int,
-    num_stable_steps: int,
     num_decay_steps: int,
+    num_training_steps: Optional[int] = None,
+    num_stable_steps: Optional[int] = None,
+    warmup_type: str = "linear",
+    decay_type: str = "cosine",
     min_lr_ratio: float = 0,
     num_cycles: float = 0.5,
     last_epoch: int = -1,
 ):
     """
     Create a schedule with a learning rate that has three stages:
-    1. linear increase from 0 to initial lr.
-    2. constant lr (equal to initial lr).
-    3. decrease following the values of the cosine function between the initial lr set in the optimizer to
-    a fraction of initial lr.
+    1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type.
+    2. stable: constant learning rate.
+    3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type.

     Args:
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
         num_warmup_steps (`int`):
             The number of steps for the warmup phase.
-        num_stable_steps (`int`):
-            The number of steps for the stable phase.
         num_decay_steps (`int`):
-            The number of steps for the cosine annealing phase.
+            The number of steps for the decay phase.
+        num_training_steps (`int`, *optional*):
+            The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`.
+        num_stable_steps (`int`, *optional*):
+            The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate.
+        warmup_type (`str`, *optional*, defaults to "linear"):
+            The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
+        decay_type (`str`, *optional*, defaults to "cosine"):
+            The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
         min_lr_ratio (`float`, *optional*, defaults to 0):
             The minimum learning rate as a ratio of the initial learning rate.
         num_cycles (`float`, *optional*, defaults to 0.5):
@@ -443,11 +469,29 @@ def get_wsd_schedule(
     Return:
         `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
     """
+    if num_training_steps is None and num_stable_steps is None:
+        raise ValueError("Either num_training_steps or num_stable_steps must be specified.")
+
+    if num_training_steps is not None and num_stable_steps is not None:
+        warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")
+
+    if warmup_type not in ["linear", "cosine", "1-sqrt"]:
+        raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")
+
+    if decay_type not in ["linear", "cosine", "1-sqrt"]:
+        raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")
+
+    if num_stable_steps is None:
+        num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
+
     lr_lambda = partial(
         _get_wsd_scheduler_lambda,
         num_warmup_steps=num_warmup_steps,
         num_stable_steps=num_stable_steps,
         num_decay_steps=num_decay_steps,
+        warmup_type=warmup_type,
+        decay_type=decay_type,
         min_lr_ratio=min_lr_ratio,
         num_cycles=num_cycles,
     )
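
The resolution logic above means the stable phase can be sized in two equivalent ways, sketched below with illustrative numbers: passing neither num_training_steps nor num_stable_steps raises a ValueError, and passing both keeps num_stable_steps with a warning.

import torch
from transformers.optimization import get_wsd_schedule

params = [torch.nn.Parameter(torch.zeros(2))]  # placeholder parameters (assumption)

# Explicit stable phase of 6 steps:
sched_a = get_wsd_schedule(torch.optim.AdamW(params, lr=1e-3),
                           num_warmup_steps=2, num_decay_steps=2, num_stable_steps=6)

# Equivalent stable phase derived from a 10-step budget: 10 - 2 - 2 = 6.
sched_b = get_wsd_schedule(torch.optim.AdamW(params, lr=1e-3),
                           num_warmup_steps=2, num_decay_steps=2, num_training_steps=10)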
@@ -541,7 +585,12 @@ def get_scheduler(
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

     if name == SchedulerType.WARMUP_STABLE_DECAY:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs)
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            **scheduler_specific_kwargs,
+        )

     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
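Because num_training_steps is now forwarded in this branch, the schedule can also be requested by name without spelling out the stable phase. A short sketch with illustrative values:

import torch
from transformers.optimization import get_scheduler

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(2))], lr=1e-3)
scheduler = get_scheduler(
    "warmup_stable_decay",
    optimizer=optimizer,
    num_warmup_steps=2,
    num_training_steps=10,
    scheduler_specific_kwargs={"num_decay_steps": 2, "warmup_type": "cosine", "decay_type": "1-sqrt"},
)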
tests/optimization/test_optimization.py

@@ -153,8 +153,8 @@ class ScheduleInitTest(unittest.TestCase):
                 [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
             ),
             get_wsd_schedule: (
-                {"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
-                [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
+                {**common_kwargs, "num_decay_steps": 2, "min_lr_ratio": 0.0},
+                [0.0, 5.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 5.0],
             ),
         }
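
The new expected values follow from common_kwargs (which, judging from the rest of this test class, supplies num_warmup_steps=2 and num_training_steps=10 — an assumption here) and a base learning rate of 10: linear warmup over two steps gives 0.0 and 5.0, the inferred stable phase holds 10.0, and the default cosine decay over the last two steps yields 10.0 then 5.0. A rough check of the final probed step:

import math

# Step 9: decay progress = (9 - 2 - 6) / 2 = 0.5; cosine decay with num_cycles=0.5.
factor = 0.5 * (1.0 + math.cos(math.pi * 0.5 * 2.0 * 0.5))
print(10 * factor)  # ~5.0, matching the last expected value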
@@ -183,14 +183,34 @@ class ScheduleInitTest(unittest.TestCase):
                 "name": "warmup_stable_decay",
                 "optimizer": self.optimizer,
                 "num_warmup_steps": 2,
-                "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
+                "num_training_steps": 10,
+                "scheduler_specific_kwargs": {
+                    "num_decay_steps": 2,
+                    "warmup_type": "linear",
+                    "decay_type": "linear",
+                },
             },
             {
                 "name": "warmup_stable_decay",
                 "optimizer": self.optimizer,
                 "num_warmup_steps": 2,
                 "num_training_steps": 10,
-                "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
+                "scheduler_specific_kwargs": {
+                    "num_decay_steps": 2,
+                    "warmup_type": "cosine",
+                    "decay_type": "cosine",
+                },
             },
+            {
+                "name": "warmup_stable_decay",
+                "optimizer": self.optimizer,
+                "num_warmup_steps": 2,
+                "num_training_steps": 10,
+                "scheduler_specific_kwargs": {
+                    "num_decay_steps": 2,
+                    "warmup_type": "1-sqrt",
+                    "decay_type": "1-sqrt",
+                },
+            },
             {"name": "cosine", "optimizer": self.optimizer, "num_warmup_steps": 2, "num_training_steps": 10},
         ]
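
For a quick feel for the three decay shapes exercised by these cases, here is a standalone sketch of the per-step decay factor, mirroring the branch formulas in _get_wsd_scheduler_lambda above (min_lr_ratio left out, num_cycles defaulting to 0.5):

import math

def decay_factor(decay_type: str, progress: float, num_cycles: float = 0.5) -> float:
    # Mirrors the three decay branches of the scheduler lambda.
    if decay_type == "linear":
        return 1.0 - progress
    if decay_type == "cosine":
        return 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
    if decay_type == "1-sqrt":
        return 1.0 - math.sqrt(progress)
    raise ValueError(f"Unknown decay type: {decay_type}")

for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(p, {t: round(decay_factor(t, p), 3) for t in ("linear", "cosine", "1-sqrt")})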