Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Merge pull request #445 from lukovnikov/master
Learning rate schedules improvement + extension

Commit 98cb7b2c51
@@ -20,33 +20,157 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
+import abc
+import sys

 logger = logging.getLogger(__name__)

-def warmup_cosine(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    x_ = (x - warmup) / (1 - warmup)  # progress after warmup
-    return 0.5 * (1. + math.cos(math.pi * x_))
-
-def warmup_constant(x, warmup=0.002):
-    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
-        Learning rate is 1. afterwards. """
-    if x < warmup:
-        return x/warmup
-    return 1.0

+if sys.version_info >= (3, 4):
+    ABC = abc.ABC
+else:
+    ABC = abc.ABCMeta('ABC', (), {})


+class _LRSchedule(ABC):
+    """ Parent of all LRSchedules here. """
+    warn_t_total = False    # is set to True for schedules where progressing beyond t_total steps doesn't make sense
+    def __init__(self, warmup=0.002, t_total=-1, **kw):
+        """
+        :param warmup: what fraction of t_total steps will be used for linear warmup
+        :param t_total: how many training steps (updates) are planned
+        :param kw:
+        """
+        super(_LRSchedule, self).__init__(**kw)
+        if t_total < 0:
+            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
+        if not 0.0 <= warmup < 1.0 and not warmup == -1:
+            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
+        warmup = max(warmup, 0.)
+        self.warmup, self.t_total = float(warmup), float(t_total)
+        self.warned_for_t_total_at_progress = -1
+
+    def get_lr(self, step, nowarn=False):
+        """
+        :param step: which of t_total steps we're on
+        :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
+        :return: learning rate multiplier for current update
+        """
+        if self.t_total < 0:
+            return 1.
+        progress = float(step) / self.t_total
+        ret = self.get_lr_(progress)
+        # warning for exceeding t_total (only active with warmup_linear
+        if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
+            logger.warning(
+                "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
+                .format(ret, self.__class__.__name__))
+            self.warned_for_t_total_at_progress = progress
+        # end warning
+        return ret
+
+    @abc.abstractmethod
+    def get_lr_(self, progress):
+        """
+        :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
+        :return: learning rate multiplier for current update
+        """
+        return 1.

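The contract above is that subclasses override only get_lr_, which maps training progress in [0, 1] to a multiplier, while get_lr handles the step-to-progress conversion and the t_total warning. A minimal sketch of a custom subclass under that contract; the class name and values are illustrative only and not part of this PR, and it assumes the module is importable as pytorch_pretrained_bert.optimization:

from pytorch_pretrained_bert.optimization import _LRSchedule

class WarmupThenHalfSchedule(_LRSchedule):
    """Hypothetical schedule: linear warmup, then a constant multiplier of 0.5."""
    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        return 0.5

sched = WarmupThenHalfSchedule(warmup=0.1, t_total=100)
print(sched.get_lr(5))    # 0.5, halfway through warmup (progress 0.05 of warmup 0.1)
print(sched.get_lr(50))   # 0.5, the constant post-warmup multiplier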
+class ConstantLR(_LRSchedule):
+    def get_lr_(self, progress):
+        return 1.


+class WarmupCosineSchedule(_LRSchedule):
+    """
+    Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
+    """
+    warn_t_total = True
+    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
+        """
+        :param warmup: see LRSchedule
+        :param t_total: see LRSchedule
+        :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
+        :param kw:
+        """
+        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
+        self.cycles = cycles
+
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)    # progress after warmup
+            return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))

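With the default cycles=0.5, the multiplier rises linearly to 1 during warmup and then follows half a cosine period down to 0 at the end of training. A small numeric check of that shape, assuming the class is importable from pytorch_pretrained_bert.optimization (values chosen only for illustration):

from pytorch_pretrained_bert.optimization import WarmupCosineSchedule

sched = WarmupCosineSchedule(warmup=0.1, t_total=1000)   # cycles defaults to 0.5
print(round(sched.get_lr(50), 3))     # 0.5 -> halfway through warmup
print(round(sched.get_lr(100), 3))    # 1.0 -> warmup just finished
print(round(sched.get_lr(550), 3))    # 0.5 -> halfway down the cosine
print(round(sched.get_lr(1000), 3))   # 0.0 -> end of training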
+class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
+    """
+    Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1).
+    """
+    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
+        super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+        assert(cycles >= 1.)
+
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)    # progress after warmup
+            ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
+            return ret

+class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
+    """
+    Cosine learning rate schedule with linear warmups and linear warmup restarts.
+    The same warmup rate is used for warmup restarts as for initial warmup.
+    The total effective fraction of warmup steps over all cycles is warmup * cycles!
+    """
+    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
+        assert(warmup * cycles < 1.)
+        warmup = warmup * cycles if warmup >= 0 else warmup
+        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+
+    def get_lr_(self, progress):
+        progress = progress * self.cycles % 1.
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)    # progress after warmup
+            ret = 0.5 * (1. + math.cos(math.pi * progress))
+            return ret

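Because __init__ rescales warmup by cycles and get_lr_ folds progress back into a single cycle, the warmup argument is the fraction of the total step budget spent warming up at the start of each cycle, which is why the docstring notes that the overall warmup fraction is warmup * cycles. A numeric sketch mirroring the new test at the bottom of this diff, assuming the import below:

from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule

sched = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
print(sched.get_lr(0))     # 0.0 -> start of the first warmup
print(sched.get_lr(50))    # 1.0 -> first warmup done (50 = 0.05 * 1000 steps)
print(sched.get_lr(200))   # 0.0 -> restart: the second cycle begins with a fresh warmup
print(sched.get_lr(250))   # 1.0 -> second warmup done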
+class WarmupConstantSchedule(_LRSchedule):
+    """
+    Applies linear warmup. After warmup always returns 1..
+    """
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return 1.


+class WarmupLinearSchedule(_LRSchedule):
+    """
+    Linear warmup. Linear decay after warmup.
+    """
+    warn_t_total = True
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return max((progress - 1.) / (self.warmup - 1.), 0.)

-def warmup_linear(x, warmup=0.002):
-    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
-        After `t_total`-th training step, learning rate is zero. """
-    if x < warmup:
-        return x/warmup
-    return max((x-1.)/(warmup-1.), 0)

 SCHEDULES = {
-    'warmup_cosine': warmup_cosine,
-    'warmup_constant': warmup_constant,
-    'warmup_linear': warmup_linear,
+    None: ConstantLR,
+    "none": ConstantLR,
+    "warmup_cosine": WarmupCosineSchedule,
+    "warmup_constant": WarmupConstantSchedule,
+    "warmup_linear": WarmupLinearSchedule
 }

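SCHEDULES now maps a schedule name to a schedule class rather than to a function; the optimizers below instantiate the class with the warmup and t_total they were given, and None or "none" selects ConstantLR. An illustrative lookup along the same lines, assuming the module path used in the tests:

from pytorch_pretrained_bert.optimization import SCHEDULES, ConstantLR

schedule = SCHEDULES["warmup_linear"](warmup=0.1, t_total=1000)
print(type(schedule).__name__)   # WarmupLinearSchedule
print(isinstance(SCHEDULES["none"](warmup=0.1, t_total=1000), ConstantLR))   # True -> "none" (or None) disables scheduling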
@@ -56,8 +180,10 @@ class BertAdam(Optimizer):
         lr: learning rate
         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
         t_total: total number of training steps for the learning
-            rate schedule, -1 means constant learning rate. Default: -1
-        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
+            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
+        schedule: schedule to use for the warmup (see above).
+            Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
+            Default: 'warmup_linear'
         b1: Adams b1. Default: 0.9
         b2: Adams b2. Default: 0.999
         e: Adams epsilon. Default: 1e-6
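As the updated docstring notes, t_total=-1 now means the multiplier stays at 1 regardless of warmup, since _LRSchedule.get_lr returns 1. whenever t_total is negative. A quick check with assumed values:

from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

sched = WarmupLinearSchedule(warmup=0.1, t_total=-1)   # logs that the schedule will not be applied
print(sched.get_lr(0), sched.get_lr(500))              # 1.0 1.0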
@@ -65,21 +191,26 @@ class BertAdam(Optimizer):
         max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
     """
     def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
-                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
-                 max_grad_norm=1.0):
+                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
         if not 0.0 <= warmup < 1.0 and not warmup == -1:
             raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         if not 0.0 <= b2 < 1.0:
             raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
-        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
+        # initialize schedule object
+        if not isinstance(schedule, _LRSchedule):
+            schedule_type = SCHEDULES[schedule]
+            schedule = schedule_type(warmup=warmup, t_total=t_total)
+        else:
+            if warmup != -1 or t_total != -1:
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
+                               "Please specify custom warmup and t_total in LRSchedule object.")
+        defaults = dict(lr=lr, schedule=schedule,
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
         super(BertAdam, self).__init__(params, defaults)
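With this change, schedule accepts either one of the names in SCHEDULES or an _LRSchedule instance; when an object is passed, warmup and t_total are configured on the schedule itself and the optimizer's own warmup/t_total arguments are left at -1. A hedged usage sketch; the learning rates and layer sizes here are arbitrary:

import torch
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule

model = torch.nn.Linear(50, 50)

# name-based schedule, as before
optim = BertAdam(model.parameters(), lr=1e-3, warmup=0.1, t_total=1000, schedule="warmup_cosine")

# schedule object: warmup/t_total live on the schedule, not on the optimizer
sched = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5)
optim = BertAdam(model.parameters(), lr=1e-3, schedule=sched)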
@@ -91,11 +222,8 @@ class BertAdam(Optimizer):
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 lr.append(lr_scheduled)
         return lr

@@ -110,8 +238,6 @@ class BertAdam(Optimizer):
         if closure is not None:
             loss = closure()

-        warned_for_t_total = False
-
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -153,19 +279,8 @@ class BertAdam(Optimizer):
                 if group['weight_decay'] > 0.0:
                     update += group['weight_decay'] * p.data

-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    progress = state['step']/group['t_total']
-                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
-                    # warning for exceeding t_total (only active with warmup_linear
-                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
-                        logger.warning(
-                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
-                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
-                        warned_for_t_total = True
-                    # end warning
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])

                 update_with_lr = lr_scheduled * update
                 p.data.add_(-update_with_lr)

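Inside step() the effective learning rate is now just the base lr scaled by the schedule's multiplier at the current step count, and BertAdam.get_lr() (modified in the hunk further up) reports the same value once optimizer state exists. A sketch of watching it rise during warmup, assuming a toy model and random inputs:

import torch
from pytorch_pretrained_bert import BertAdam

model = torch.nn.Linear(50, 50)
optim = BertAdam(model.parameters(), lr=1e-3, warmup=0.1, t_total=100, schedule="warmup_linear")

for step in range(5):
    loss = model(torch.randn(8, 50)).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad()
    print(optim.get_lr()[0])   # 1e-3 scaled by WarmupLinearSchedule.get_lr(step), rising during warmup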
@@ -20,35 +20,11 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
+from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
+    WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule

 logger = logging.getLogger(__name__)

-def warmup_cosine(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    x_ = (x - warmup) / (1 - warmup)  # progress after warmup
-    return 0.5 * (1. + math.cos(math.pi * x_))
-
-def warmup_constant(x, warmup=0.002):
-    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
-        Learning rate is 1. afterwards. """
-    if x < warmup:
-        return x/warmup
-    return 1.0
-
-def warmup_linear(x, warmup=0.002):
-    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
-        After `t_total`-th training step, learning rate is zero. """
-    if x < warmup:
-        return x/warmup
-    return max((x-1.)/(warmup-1.), 0)
-
-SCHEDULES = {
-    'warmup_cosine':warmup_cosine,
-    'warmup_constant':warmup_constant,
-    'warmup_linear':warmup_linear,
-}


 class OpenAIAdam(Optimizer):
     """Implements Open AI version of Adam algorithm with weight decay fix.
@@ -58,17 +34,23 @@ class OpenAIAdam(Optimizer):
                  vector_l2=False, max_grad_norm=-1, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
         if not 0.0 <= warmup < 1.0 and not warmup == -1:
             raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
-            raise ValueError("Invalid b1 parameter: {}".format(b1))
+            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         if not 0.0 <= b2 < 1.0:
-            raise ValueError("Invalid b2 parameter: {}".format(b2))
+            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
         if not e >= 0.0:
-            raise ValueError("Invalid epsilon value: {}".format(e))
-        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
+            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
+        # initialize schedule object
+        if not isinstance(schedule, _LRSchedule):
+            schedule_type = SCHEDULES[schedule]
+            schedule = schedule_type(warmup=warmup, t_total=t_total)
+        else:
+            if warmup != -1 or t_total != -1:
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
+                               "Please specify custom warmup and t_total in LRSchedule object.")
+        defaults = dict(lr=lr, schedule=schedule,
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                         max_grad_norm=max_grad_norm)
         super(OpenAIAdam, self).__init__(params, defaults)
@@ -80,11 +62,8 @@ class OpenAIAdam(Optimizer):
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 lr.append(lr_scheduled)
         return lr

@@ -99,8 +78,6 @@ class OpenAIAdam(Optimizer):
         if closure is not None:
             loss = closure()

-        warned_for_t_total = False
-
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -136,19 +113,8 @@ class OpenAIAdam(Optimizer):
                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']

-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    progress = state['step']/group['t_total']
-                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
-                    # warning for exceeding t_total (only active with warmup_linear
-                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
-                        logger.warning(
-                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
-                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
-                        warned_for_t_total = True
-                    # end warning
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])

                 step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1

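OpenAIAdam gets the same treatment: the schedule machinery is imported from the shared optimization module, and a schedule object can be passed in place of a name. A minimal sketch with assumed hyperparameters:

import torch
from pytorch_pretrained_bert import OpenAIAdam
from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

model = torch.nn.Linear(50, 50)
sched = WarmupLinearSchedule(warmup=0.002, t_total=10000)
optim = OpenAIAdam(model.parameters(), lr=6.25e-5, schedule=sched)   # warmup/t_total come from the schedule object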
@@ -21,6 +21,10 @@ import unittest
 import torch

 from pytorch_pretrained_bert import BertAdam
+from pytorch_pretrained_bert import OpenAIAdam
+from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupCosineWithWarmupRestartsSchedule
+import numpy as np


 class OptimizationTest(unittest.TestCase):
@@ -46,5 +50,43 @@ class OptimizationTest(unittest.TestCase):
         self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)


+class ScheduleInitTest(unittest.TestCase):
+    def test_bert_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
+    def test_openai_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
+
+class WarmupCosineWithRestartsTest(unittest.TestCase):
+    def test_it(self):
+        m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
+        x = np.arange(0, 1000)
+        y = [m.get_lr(xe) for xe in x]
+        y = np.asarray(y)
+        expected_zeros = y[[0, 200, 400, 600, 800]]
+        print(expected_zeros)
+        expected_ones = y[[50, 250, 450, 650, 850]]
+        print(expected_ones)
+        self.assertTrue(np.allclose(expected_ones, 1))
+        self.assertTrue(np.allclose(expected_zeros, 0))


 if __name__ == "__main__":
     unittest.main()