mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add Trainer support for ReduceLROnPlateau (#23010)
* Add Trainer support for ReduceLROnPlateau Fixes #16503 * Remove training argument and add default instance --------- Co-authored-by: mmeloux <maxime.meloux@loria.fr>
This commit is contained in:
parent
cf7baf4060
commit
9b435204b1
@ -22,7 +22,7 @@ from typing import Callable, Iterable, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
||||
|
||||
from .trainer_utils import SchedulerType
|
||||
from .utils import logging
|
||||
@ -49,6 +49,21 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
||||
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
||||
|
||||
|
||||
def get_reduce_on_plateau_schedule(optimizer: Optimizer):
|
||||
"""
|
||||
Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
|
||||
|
||||
Args:
|
||||
optimizer ([`~torch.optim.Optimizer`]):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
|
||||
Return:
|
||||
`torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
|
||||
"""
|
||||
|
||||
return ReduceLROnPlateau(optimizer)
|
||||
|
||||
|
||||
def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1.0, num_warmup_steps))
|
||||
@ -309,6 +324,7 @@ TYPE_TO_SCHEDULER_FUNCTION = {
|
||||
SchedulerType.CONSTANT: get_constant_schedule,
|
||||
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
||||
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
|
||||
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
|
||||
}
|
||||
|
||||
|
||||
@ -335,7 +351,7 @@ def get_scheduler(
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
|
||||
return schedule_func(optimizer)
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
|
@ -1997,7 +1997,9 @@ class Trainer:
|
||||
self.optimizer.step()
|
||||
|
||||
if optimizer_was_run and not self.deepspeed:
|
||||
self.lr_scheduler.step()
|
||||
# Delay optimizer scheduling until metrics are generated
|
||||
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
model.zero_grad()
|
||||
self.state.global_step += 1
|
||||
@ -2288,6 +2290,10 @@ class Trainer:
|
||||
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
|
||||
self._report_to_hp_search(trial, self.state.global_step, metrics)
|
||||
|
||||
# Run delayed LR scheduler now that metrics are populated
|
||||
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||
self.lr_scheduler.step(metrics[self.args.metric_for_best_model])
|
||||
|
||||
if self.control.should_save:
|
||||
self._save_checkpoint(model, trial, metrics=metrics)
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
@ -367,6 +367,7 @@ class SchedulerType(ExplicitEnum):
|
||||
CONSTANT = "constant"
|
||||
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
||||
INVERSE_SQRT = "inverse_sqrt"
|
||||
REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
|
||||
|
||||
|
||||
class TrainerMemoryTracker:
|
||||
|
@ -1194,7 +1194,9 @@ class TrainingArguments:
|
||||
f"https://github.com/huggingface/safetensors!"
|
||||
)
|
||||
|
||||
if self.load_best_model_at_end and self.metric_for_best_model is None:
|
||||
if (
|
||||
self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
|
||||
) and self.metric_for_best_model is None:
|
||||
self.metric_for_best_model = "loss"
|
||||
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
||||
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
|
||||
@ -1234,6 +1236,12 @@ class TrainingArguments:
|
||||
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
||||
raise ValueError("sharded_ddp is not supported with bf16")
|
||||
|
||||
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
|
||||
if self.evaluation_strategy == IntervalStrategy.NO:
|
||||
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
|
||||
if not is_torch_available():
|
||||
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
|
||||
|
||||
self.optim = OptimizerNames(self.optim)
|
||||
if self.adafactor:
|
||||
warnings.warn(
|
||||
|
@ -575,6 +575,74 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
|
||||
|
||||
def test_reduce_lr_on_plateau_args(self):
|
||||
# test passed arguments for a custom ReduceLROnPlateau scheduler
|
||||
train_dataset = RegressionDataset(length=64)
|
||||
eval_dataset = RegressionDataset(length=64)
|
||||
args = TrainingArguments(
|
||||
"./regression",
|
||||
evaluation_strategy="epoch",
|
||||
metric_for_best_model="eval_loss",
|
||||
)
|
||||
model = RegressionModel()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
|
||||
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, cooldown=2)
|
||||
trainer = Trainer(
|
||||
model, args, train_dataset=train_dataset, eval_dataset=eval_dataset, optimizers=(optimizer, lr_scheduler)
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
|
||||
self.assertEqual(trainer.lr_scheduler.factor, 0.2)
|
||||
self.assertEqual(trainer.lr_scheduler.patience, 5)
|
||||
self.assertEqual(trainer.lr_scheduler.cooldown, 2)
|
||||
|
||||
def test_reduce_lr_on_plateau(self):
|
||||
# test the ReduceLROnPlateau scheduler
|
||||
|
||||
class TrainerWithLRLogs(Trainer):
|
||||
def log(self, logs):
|
||||
# the LR is computed after metrics and does not exist for the first epoch
|
||||
if hasattr(self.lr_scheduler, "_last_lr"):
|
||||
logs["learning_rate"] = self.lr_scheduler._last_lr
|
||||
super().log(logs)
|
||||
|
||||
train_dataset = RegressionDataset(length=64)
|
||||
eval_dataset = RegressionDataset(length=64)
|
||||
|
||||
args = TrainingArguments(
|
||||
"./regression",
|
||||
lr_scheduler_type="reduce_lr_on_plateau",
|
||||
evaluation_strategy="epoch",
|
||||
metric_for_best_model="eval_loss",
|
||||
num_train_epochs=10,
|
||||
learning_rate=0.2,
|
||||
)
|
||||
model = RegressionModel()
|
||||
trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||
trainer.train()
|
||||
|
||||
self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
|
||||
patience = trainer.lr_scheduler.patience
|
||||
|
||||
logs = trainer.state.log_history[1:]
|
||||
best_loss = logs[0]["eval_loss"]
|
||||
bad_epochs = 0
|
||||
for i, log in enumerate(logs[:-1]): # Compare learning rate to next epoch's
|
||||
loss = log["eval_loss"]
|
||||
just_decreased = False
|
||||
if loss > best_loss:
|
||||
bad_epochs += 1
|
||||
if bad_epochs > patience:
|
||||
self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
||||
just_decreased = True
|
||||
bad_epochs = 0
|
||||
else:
|
||||
best_loss = loss
|
||||
bad_epochs = 0
|
||||
if not just_decreased:
|
||||
self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
||||
|
||||
def test_adafactor_lr_none(self):
|
||||
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user