Add Trainer support for ReduceLROnPlateau (#23010)

* Add Trainer support for ReduceLROnPlateau

Fixes #16503

* Remove training argument and add default instance

---------

Co-authored-by: mmeloux <maxime.meloux@loria.fr>
Maxime Méloux 2023-04-28 15:17:30 +02:00 committed by GitHub
parent cf7baf4060
commit 9b435204b1
5 changed files with 103 additions and 4 deletions

src/transformers/optimization.py

@@ -22,7 +22,7 @@ from typing import Callable, Iterable, Optional, Tuple, Union
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from .trainer_utils import SchedulerType
from .utils import logging
@@ -49,6 +49,21 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
def get_reduce_on_plateau_schedule(optimizer: Optimizer):
"""
Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
Return:
`torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
"""
return ReduceLROnPlateau(optimizer)
def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1.0, num_warmup_steps))
@@ -309,6 +324,7 @@ TYPE_TO_SCHEDULER_FUNCTION = {
SchedulerType.CONSTANT: get_constant_schedule,
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
}
@@ -335,7 +351,7 @@ def get_scheduler(
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
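
A minimal sketch of how the new branch behaves through the public get_scheduler API (the toy model and the literal metric value below are illustrative placeholders, not part of this commit):

import torch
from transformers import get_scheduler

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# "reduce_lr_on_plateau" is dispatched without warmup or training-step counts
scheduler = get_scheduler("reduce_lr_on_plateau", optimizer=optimizer)
assert isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

# unlike the LambdaLR-based schedules, it is stepped with a metric value
scheduler.step(0.42)  # e.g. an eval loss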

src/transformers/trainer.py

@@ -1997,7 +1997,9 @@ class Trainer:
self.optimizer.step()
if optimizer_was_run and not self.deepspeed:
self.lr_scheduler.step()
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
@@ -2288,6 +2290,10 @@
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)
# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step(metrics[self.args.metric_for_best_model])
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
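
Both hunks implement one pattern: the per-step scheduler call is skipped for ReduceLROnPlateau, and the step is issued only after evaluation, once the monitored metric exists. A standalone sketch of that pattern in a plain PyTorch loop (the model, data, and evaluate() helper are illustrative placeholders, not Trainer internals):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

def evaluate():
    # placeholder: return the monitored metric, e.g. an eval loss
    return torch.rand(()).item()

for epoch in range(10):
    for _ in range(16):  # placeholder training steps
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        # no scheduler.step() here: ReduceLROnPlateau needs the metric,
        # so the call is delayed until after evaluation
    eval_loss = evaluate()
    scheduler.step(eval_loss)  # delayed step with the metric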

src/transformers/trainer_utils.py

@@ -367,6 +367,7 @@ class SchedulerType(ExplicitEnum):
CONSTANT = "constant"
CONSTANT_WITH_WARMUP = "constant_with_warmup"
INVERSE_SQRT = "inverse_sqrt"
REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
class TrainerMemoryTracker:
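
The enum value doubles as the string users pass for lr_scheduler_type; a quick sketch of the round-trip that get_scheduler (via SchedulerType(name)) and TrainingArguments rely on:

from transformers.trainer_utils import SchedulerType

assert SchedulerType("reduce_lr_on_plateau") is SchedulerType.REDUCE_ON_PLATEAU
assert SchedulerType.REDUCE_ON_PLATEAU.value == "reduce_lr_on_plateau"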

src/transformers/training_args.py

@@ -1194,7 +1194,9 @@ class TrainingArguments:
f"https://github.com/huggingface/safetensors!"
)
if self.load_best_model_at_end and self.metric_for_best_model is None:
if (
self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
) and self.metric_for_best_model is None:
self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None:
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
@@ -1234,6 +1236,12 @@
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.evaluation_strategy == IntervalStrategy.NO:
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
if not is_torch_available():
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
self.optim = OptimizerNames(self.optim)
if self.adafactor:
warnings.warn(

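The checks added here can be exercised directly when constructing TrainingArguments; a sketch assuming this commit, with "out" as a placeholder output directory:

from transformers import TrainingArguments

# default evaluation_strategy is "no", so the new validation raises
try:
    TrainingArguments("out", lr_scheduler_type="reduce_lr_on_plateau")
except ValueError as e:
    print(e)  # lr_scheduler_type reduce_lr_on_plateau requires an eval strategy

args = TrainingArguments(
    "out",
    lr_scheduler_type="reduce_lr_on_plateau",
    evaluation_strategy="epoch",
)
print(args.metric_for_best_model)  # falls back to "loss" via the new default
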
tests/trainer/test_trainer.py

@@ -575,6 +575,74 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
def test_reduce_lr_on_plateau_args(self):
# test passed arguments for a custom ReduceLROnPlateau scheduler
train_dataset = RegressionDataset(length=64)
eval_dataset = RegressionDataset(length=64)
args = TrainingArguments(
"./regression",
evaluation_strategy="epoch",
metric_for_best_model="eval_loss",
)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, cooldown=2)
trainer = Trainer(
model, args, train_dataset=train_dataset, eval_dataset=eval_dataset, optimizers=(optimizer, lr_scheduler)
)
trainer.train()
self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
self.assertEqual(trainer.lr_scheduler.factor, 0.2)
self.assertEqual(trainer.lr_scheduler.patience, 5)
self.assertEqual(trainer.lr_scheduler.cooldown, 2)
def test_reduce_lr_on_plateau(self):
# test the ReduceLROnPlateau scheduler
class TrainerWithLRLogs(Trainer):
def log(self, logs):
# the LR is computed after metrics and does not exist for the first epoch
if hasattr(self.lr_scheduler, "_last_lr"):
logs["learning_rate"] = self.lr_scheduler._last_lr
super().log(logs)
train_dataset = RegressionDataset(length=64)
eval_dataset = RegressionDataset(length=64)
args = TrainingArguments(
"./regression",
lr_scheduler_type="reduce_lr_on_plateau",
evaluation_strategy="epoch",
metric_for_best_model="eval_loss",
num_train_epochs=10,
learning_rate=0.2,
)
model = RegressionModel()
trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train()
self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
patience = trainer.lr_scheduler.patience
logs = trainer.state.log_history[1:]
best_loss = logs[0]["eval_loss"]
bad_epochs = 0
for i, log in enumerate(logs[:-1]): # Compare learning rate to next epoch's
loss = log["eval_loss"]
just_decreased = False
if loss > best_loss:
bad_epochs += 1
if bad_epochs > patience:
self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
just_decreased = True
bad_epochs = 0
else:
best_loss = loss
bad_epochs = 0
if not just_decreased:
self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
def test_adafactor_lr_none(self):
# test the special case where lr=None, since Trainer can't not have lr_scheduler