Support custom scheduler in deepspeed training (#26831)

Reuse trainer.create_scheduler to create the scheduler for DeepSpeed; a usage sketch follows the commit metadata below.
Ziyang authored 2024-02-05 10:33:55 +08:00, committed by GitHub
parent ca8944c4e3
commit 7b702836af
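For context, the behavior this change enables: a learning-rate scheduler customized on the Trainer side (for example through a subclass that overrides create_scheduler) now also takes effect when DeepSpeed owns the optimizer and the ds_config declares no scheduler; previously the callable always went through get_scheduler and ignored such overrides. A minimal usage sketch; the MyTrainer name and the cosine schedule are illustrative assumptions, not part of this commit:

import torch
from transformers import Trainer


class MyTrainer(Trainer):
    # Trainer hook that, with this change, DeepSpeed's scheduler creation is
    # routed through (via the DummyScheduler callable shown in the diff below).
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        if self.lr_scheduler is None:
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer if optimizer is not None else self.optimizer,
                T_max=num_training_steps,
            )
        return self.lr_scheduler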

src/transformers/integrations/deepspeed.py

@@ -14,7 +14,7 @@
 """
 Integration with Deepspeed
 """
+import copy
 import importlib.metadata as importlib_metadata
 import importlib.util
 import weakref
@@ -27,7 +27,6 @@ from ..utils import is_accelerate_available, is_torch_available, logging
 if is_torch_available():
     import torch
-    from ..optimization import get_scheduler
 
 logger = logging.get_logger(__name__)
@@ -341,12 +340,15 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
         if isinstance(optimizer, DummyOptim):
 
             def _lr_scheduler_callable(optimizer):
-                return get_scheduler(
-                    trainer.args.lr_scheduler_type,
-                    optimizer=optimizer,
-                    num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
-                    num_training_steps=num_training_steps,
-                )
+                # create a shallow copy first, so later modifications do not affect original trainer
+                trainer_copy = copy.copy(trainer)
+                # at the time _lr_scheduler_callable is called, trainer.lr_scheduler has been set
+                # update it to None so that we can re-create a new scheduler
+                trainer_copy.lr_scheduler = None
+                lr_scheduler = trainer_copy.create_scheduler(
+                    num_training_steps=num_training_steps, optimizer=optimizer
+                )
+                return lr_scheduler
 
             lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
         else:
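Why _lr_scheduler_callable copies the trainer and clears lr_scheduler: Trainer.create_scheduler only builds a scheduler while self.lr_scheduler is still None and otherwise returns the existing one, and, as the in-line comment notes, the original trainer's lr_scheduler has already been set by the time the callable runs, so calling create_scheduler on the real trainer would just hand back that placeholder. The shallow copy keeps the reset from mutating the real trainer while still sharing its args and optimizer. A trimmed sketch of that guard, not the verbatim implementation:

# Sketch of transformers.Trainer.create_scheduler (trimmed); get_scheduler
# here refers to transformers.optimization.get_scheduler.
def create_scheduler(self, num_training_steps: int, optimizer=None):
    if self.lr_scheduler is None:  # an already-assigned scheduler is returned unchanged
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            optimizer=self.optimizer if optimizer is None else optimizer,
            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )
    return self.lr_scheduler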