Support custom scheduler in deepspeed training (#26831)
Reuse trainer.create_scheduler to create the scheduler for DeepSpeed training.
commit 7b702836af
parent ca8944c4e3
@@ -14,7 +14,7 @@
 """
 Integration with Deepspeed
 """
-
+import copy
 import importlib.metadata as importlib_metadata
 import importlib.util
 import weakref
@@ -27,7 +27,6 @@ from ..utils import is_accelerate_available, is_torch_available, logging
 
 if is_torch_available():
     import torch
 
-    from ..optimization import get_scheduler
 
 logger = logging.get_logger(__name__)
@@ -341,12 +340,15 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
     if isinstance(optimizer, DummyOptim):
 
         def _lr_scheduler_callable(optimizer):
-            return get_scheduler(
-                trainer.args.lr_scheduler_type,
-                optimizer=optimizer,
-                num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
-                num_training_steps=num_training_steps,
-            )
+            # create a shallow copy first, so later modifications do not affect original trainer
+            trainer_copy = copy.copy(trainer)
+            # at the time _lr_scheduler_callable is called, trainer.lr_scheduler has been set
+            # update it to None so that we can re-create a new scheduler
+            trainer_copy.lr_scheduler = None
+            lr_scheduler = trainer_copy.create_scheduler(
+                num_training_steps=num_training_steps, optimizer=optimizer
+            )
+            return lr_scheduler
 
         lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
     else: