diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 14e4dd9b7cd..2e6ac0161c6 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -582,6 +582,7 @@ def get_scheduler( if name == SchedulerType.INVERSE_SQRT: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + # wsd scheduler requires either num_training_steps or num_stable_steps if name == SchedulerType.WARMUP_STABLE_DECAY: return schedule_func( optimizer,