diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index 124813b22ab..b3861b371a2 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -53,19 +53,22 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
     return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
 
 
-def get_reduce_on_plateau_schedule(optimizer: Optimizer):
+def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
     """
     Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
 
     Args:
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
+        kwargs (`dict`, *optional*):
+            Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
+            for possible parameters.
 
     Return:
         `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
     """
 
-    return ReduceLROnPlateau(optimizer)
+    return ReduceLROnPlateau(optimizer, **kwargs)
 
 
 def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
@@ -359,9 +362,15 @@ def get_scheduler(
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
+    if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
 
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    if name == SchedulerType.REDUCE_ON_PLATEAU:
+        return schedule_func(optimizer, **scheduler_specific_kwargs)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
@@ -376,9 +385,6 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
-    if scheduler_specific_kwargs is None:
-        scheduler_specific_kwargs = {}
-
     return schedule_func(
         optimizer,
         num_warmup_steps=num_warmup_steps,
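
A minimal usage sketch (not part of the diff): with this change, `ReduceLROnPlateau`-specific arguments such as `factor` or `patience` can be forwarded through `get_scheduler` via `scheduler_specific_kwargs`. The toy model, optimizer, and metric value below are placeholders assumed only for illustration.

    import torch
    from transformers import get_scheduler

    # Placeholder model and optimizer, assumed for this example.
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

    # Extra parameters are now forwarded to torch.optim.lr_scheduler.ReduceLROnPlateau.
    scheduler = get_scheduler(
        "reduce_lr_on_plateau",
        optimizer=optimizer,
        scheduler_specific_kwargs={"mode": "min", "factor": 0.5, "patience": 2},
    )

    # ReduceLROnPlateau steps on a monitored metric rather than on every optimizer step.
    validation_loss = 0.42  # placeholder metric value
    scheduler.step(validation_loss)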