From b6c1cae67b139c6c27112147375b04ca1cfec3f8 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Mon, 18 Mar 2019 13:32:04 +0100 Subject: [PATCH 01/13] branches, optim cosine fix --- pytorch_pretrained_bert/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 9a873e221b8..eb24c3bd37a 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -52,7 +52,7 @@ class LRSchedule(object): def get_lr_(self, step): return 1. - # raise NotImplemented("use subclass") + # raise NotImplemented("use subclass") - class WarmupCosineSchedule(LRSchedule): From 262a9992d7ab348dfc35bda6c550fbbba8f5bc42 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Mon, 18 Mar 2019 18:29:12 +0100 Subject: [PATCH 02/13] class weights --- pytorch_pretrained_bert/optimization.py | 21 ++++++++++++++++++--- tests/optimization_test.py | 15 ++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index eb24c3bd37a..a39a18cea3c 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -24,7 +24,8 @@ import logging logger = logging.getLogger(__name__) -__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", "WarmupCosineWithRestartsSchedule"] +__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", + "WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"] class LRSchedule(object): @@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule): return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) -class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule): +class WarmupMultiCosineSchedule(WarmupCosineSchedule): warn_t_total = True def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): - super(WarmupCosineWithRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) + super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) + assert(cycles >= 1.) def get_lr_(self, progress): if self.t_total <= 0: @@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule): return ret +class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule): + def get_lr_(self, progress): + if self.t_total <= 0.: + return 1. + progress = progress * self.cycles % 1. + if progress < self.warmup: + return progress / self.warmup + else: + progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup + ret = 0.5 * (1. 
+ math.cos(math.pi * progress)) + return ret + + class WarmupConstantSchedule(LRSchedule): warn_t_total = False def get_lr_(self, progress): diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 848b9d1cf5c..3f9f8abbfe8 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -20,7 +20,9 @@ import unittest import torch -from pytorch_pretrained_bert import BertAdam +from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule +from matplotlib import pyplot as plt +import numpy as np class OptimizationTest(unittest.TestCase): @@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase): self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) +class WarmupCosineWithRestartsTest(unittest.TestCase): + def test_it(self): + m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3) + x = np.arange(0, 1000) / 1000 + y = [m.get_lr_(xe) for xe in x] + plt.plot(y) + plt.show() + + + + if __name__ == "__main__": unittest.main() From 1758c8fc722bc2b8a80bca6786d891fbe46fb7a2 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 16:08:34 +0200 Subject: [PATCH 03/13] - updated docs for optimization --- pytorch_pretrained_bert/optimization.py | 68 +++++++++++++----- .../optimization_openai.py | 71 +++++-------------- 2 files changed, 70 insertions(+), 69 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index a39a18cea3c..565d3bff453 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -25,12 +25,18 @@ logger = logging.getLogger(__name__) __all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", - "WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"] + "WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"] class LRSchedule(object): - warn_t_total = False + """ Parent of all LRSchedules here. """ + warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense def __init__(self, warmup=0.002, t_total=-1, **kw): + """ + :param warmup: what fraction of t_total steps will be used for linear warmup + :param t_total: how many training steps (updates) are planned + :param kw: + """ super(LRSchedule, self).__init__(**kw) self.warmup, self.t_total = warmup, t_total if t_total <= 0: @@ -40,6 +46,11 @@ class LRSchedule(object): self.warned_for_t_total_at_progress = -1 def get_lr(self, step, nowarn=False): + """ + :param step: which of t_total steps we're on + :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps + :return: learning rate multiplier for current update + """ progress = step / self.t_total ret = self.get_lr_(progress) # warning for exceeding t_total (only active with warmup_linear @@ -51,14 +62,27 @@ class LRSchedule(object): # end warning return ret - def get_lr_(self, step): + def get_lr_(self, progress): + """ + :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress + :return: learning rate multiplier for current update + """ return 1. # raise NotImplemented("use subclass") - class WarmupCosineSchedule(LRSchedule): + """ + Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts. 
+ """ warn_t_total = True def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): + """ + :param warmup: see LRSchedule + :param t_total: see LRSchedule + :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. + :param kw: + """ super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) self.cycles = cycles @@ -73,10 +97,12 @@ class WarmupCosineSchedule(LRSchedule): return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) -class WarmupMultiCosineSchedule(WarmupCosineSchedule): - warn_t_total = True +class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): + """ + Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1). + """ def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): - super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) + super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) assert(cycles >= 1.) def get_lr_(self, progress): @@ -90,7 +116,16 @@ class WarmupMultiCosineSchedule(WarmupCosineSchedule): return ret -class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule): +class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): + """ + Cosine learning rate schedule with linear warmups and linear warmup restarts. + The same warmup rate is used for warmup restarts as for initial warmup. + The total effective fraction of warmup steps over all cycles is warmup * cycles! + """ + def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): + assert(warmup * cycles < 1.) + super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw) + def get_lr_(self, progress): if self.t_total <= 0.: return 1. @@ -104,7 +139,9 @@ class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule): class WarmupConstantSchedule(LRSchedule): - warn_t_total = False + """ + Applies linear warmup. After warmup always returns 1.. + """ def get_lr_(self, progress): if progress < self.warmup: return progress / self.warmup @@ -112,6 +149,9 @@ class WarmupConstantSchedule(LRSchedule): class WarmupLinearSchedule(LRSchedule): + """ + Linear warmup. Linear decay after warmup. + """ warn_t_total = True def get_lr_(self, progress): if progress < self.warmup: @@ -145,8 +185,7 @@ class BertAdam(Optimizer): max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 """ def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', - b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0., - max_grad_norm=1.0): + b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES: @@ -163,9 +202,10 @@ class BertAdam(Optimizer): schedule = schedule_type(warmup=warmup, t_total=t_total) else: if warmup != -1 or t_total != -1: - logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.") + logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. 
" + "Please specify custom warmup and t_total in LRSchedule object.") defaults = dict(lr=lr, schedule=schedule, - b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay, + b1=b1, b2=b2, e=e, weight_decay=weight_decay, max_grad_norm=max_grad_norm) super(BertAdam, self).__init__(params, defaults) @@ -176,10 +216,8 @@ class BertAdam(Optimizer): state = self.state[p] if len(state) == 0: return [0] - lr_scheduled = group['lr'] lr_scheduled *= group['schedule'].get_lr(state['step']) - lr.append(lr_scheduled) return lr @@ -235,8 +273,6 @@ class BertAdam(Optimizer): if group['weight_decay'] > 0.0: update += group['weight_decay'] * p.data - # TODO: init weight decay - lr_scheduled = group['lr'] lr_scheduled *= group['schedule'].get_lr(state['step']) diff --git a/pytorch_pretrained_bert/optimization_openai.py b/pytorch_pretrained_bert/optimization_openai.py index 99ac15e1089..5bfea476a68 100644 --- a/pytorch_pretrained_bert/optimization_openai.py +++ b/pytorch_pretrained_bert/optimization_openai.py @@ -20,35 +20,10 @@ from torch.optim import Optimizer from torch.optim.optimizer import required from torch.nn.utils import clip_grad_norm_ import logging +from .optimization import * logger = logging.getLogger(__name__) -def warmup_cosine(x, warmup=0.002): - if x < warmup: - return x/warmup - x_ = (x - warmup) / (1 - warmup) # progress after warmup - return 0.5 * (1. + math.cos(math.pi * x_)) - -def warmup_constant(x, warmup=0.002): - """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps. - Learning rate is 1. afterwards. """ - if x < warmup: - return x/warmup - return 1.0 - -def warmup_linear(x, warmup=0.002): - """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step. - After `t_total`-th training step, learning rate is zero. """ - if x < warmup: - return x/warmup - return max((x-1.)/(warmup-1.), 0) - -SCHEDULES = { - 'warmup_cosine':warmup_cosine, - 'warmup_constant':warmup_constant, - 'warmup_linear':warmup_linear, -} - class OpenAIAdam(Optimizer): """Implements Open AI version of Adam algorithm with weight decay fix. 
@@ -58,17 +33,23 @@ class OpenAIAdam(Optimizer): vector_l2=False, max_grad_norm=-1, **kwargs): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) - if schedule not in SCHEDULES: + if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) - if not 0.0 <= warmup < 1.0 and not warmup == -1: - raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) if not 0.0 <= b1 < 1.0: - raise ValueError("Invalid b1 parameter: {}".format(b1)) + raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) if not 0.0 <= b2 < 1.0: - raise ValueError("Invalid b2 parameter: {}".format(b2)) + raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) if not e >= 0.0: - raise ValueError("Invalid epsilon value: {}".format(e)) - defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) + # initialize schedule object + if not isinstance(schedule, LRSchedule): + schedule_type = SCHEDULES[schedule] + schedule = schedule_type(warmup=warmup, t_total=t_total) + else: + if warmup != -1 or t_total != -1: + logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. " + "Please specify custom warmup and t_total in LRSchedule object.") + defaults = dict(lr=lr, schedule=schedule, b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, max_grad_norm=max_grad_norm) super(OpenAIAdam, self).__init__(params, defaults) @@ -80,11 +61,8 @@ class OpenAIAdam(Optimizer): state = self.state[p] if len(state) == 0: return [0] - if group['t_total'] != -1: - schedule_fct = SCHEDULES[group['schedule']] - lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) - else: - lr_scheduled = group['lr'] + lr_scheduled = group['lr'] + lr_scheduled *= group['schedule'].get_lr(state['step']) lr.append(lr_scheduled) return lr @@ -99,8 +77,6 @@ class OpenAIAdam(Optimizer): if closure is not None: loss = closure() - warned_for_t_total = False - for group in self.param_groups: for p in group['params']: if p.grad is None: @@ -136,19 +112,8 @@ class OpenAIAdam(Optimizer): bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] - if group['t_total'] != -1: - schedule_fct = SCHEDULES[group['schedule']] - progress = state['step']/group['t_total'] - lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) - # warning for exceeding t_total (only active with warmup_linear - if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: - logger.warning( - "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. 
" - "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) - warned_for_t_total = True - # end warning - else: - lr_scheduled = group['lr'] + lr_scheduled = group['lr'] + lr_scheduled *= group['schedule'].get_lr(state['step']) step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 From d164867d90c7b352445aa7d4028a6ba156a70a77 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 16:13:51 +0200 Subject: [PATCH 04/13] - updated docs for optimization --- tests/optimization_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 3f9f8abbfe8..8c28ad38adf 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -20,7 +20,8 @@ import unittest import torch -from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule +from pytorch_pretrained_bert import BertAdam +from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule from matplotlib import pyplot as plt import numpy as np @@ -50,7 +51,7 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3) + m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=3) x = np.arange(0, 1000) / 1000 y = [m.get_lr_(xe) for xe in x] plt.plot(y) From b64cc63a772dc981040ad8747efbf319ebb4945a Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 16:42:40 +0200 Subject: [PATCH 05/13] optimization schedule test update --- tests/optimization_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 8c28ad38adf..218da7581f8 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -51,11 +51,18 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=3) + m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5) x = np.arange(0, 1000) / 1000 y = [m.get_lr_(xe) for xe in x] plt.plot(y) - plt.show() + plt.show(block=False) + y = np.asarray(y) + expected_zeros = y[[0, 200, 400, 600, 800]] + print(expected_zeros) + expected_ones = y[[50, 250, 450, 650, 850]] + print(expected_ones) + self.assertTrue(np.allclose(expected_ones, 1)) + self.assertTrue(np.allclose(expected_zeros, 0)) From 91a073f80458ef7ca65f8b1e4af35b8061155794 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 17:10:08 +0200 Subject: [PATCH 06/13] schedule fix --- pytorch_pretrained_bert/optimization.py | 17 +++++++---------- tests/optimization_test.py | 6 +++--- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 565d3bff453..8c8dc3b8624 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -38,11 +38,12 @@ class LRSchedule(object): :param kw: """ super(LRSchedule, self).__init__(**kw) - self.warmup, self.t_total = warmup, t_total if t_total <= 0: logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) + warmup = max(warmup, 0) + 
self.warmup, self.t_total = warmup, t_total self.warned_for_t_total_at_progress = -1 def get_lr(self, step, nowarn=False): @@ -51,6 +52,8 @@ class LRSchedule(object): :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps :return: learning rate multiplier for current update """ + if self.t_total < 0: + return 1. progress = step / self.t_total ret = self.get_lr_(progress) # warning for exceeding t_total (only active with warmup_linear @@ -87,9 +90,6 @@ class WarmupCosineSchedule(LRSchedule): self.cycles = cycles def get_lr_(self, progress): - """ get learning rate multiplier """ - if self.t_total <= 0: - return 1. if progress < self.warmup: return progress / self.warmup else: @@ -106,8 +106,6 @@ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): assert(cycles >= 1.) def get_lr_(self, progress): - if self.t_total <= 0: - return 1. if progress < self.warmup: return progress / self.warmup else: @@ -124,11 +122,10 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul """ def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): assert(warmup * cycles < 1.) - super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw) + warmup = warmup * cycles if warmup >= 0 else warmup + super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) def get_lr_(self, progress): - if self.t_total <= 0.: - return 1. progress = progress * self.cycles % 1. if progress < self.warmup: return progress / self.warmup @@ -174,7 +171,7 @@ class BertAdam(Optimizer): lr: learning rate warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 t_total: total number of training steps for the learning - rate schedule, -1 means constant learning rate. Default: -1 + rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 schedule: schedule to use for the warmup (see above). Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object. 
Default: 'warmup_linear' diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 218da7581f8..0eaae16d310 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -51,9 +51,9 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5) - x = np.arange(0, 1000) / 1000 - y = [m.get_lr_(xe) for xe in x] + m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5) + x = np.arange(0, 1000) + y = [m.get_lr(xe) for xe in x] plt.plot(y) plt.show(block=False) y = np.asarray(y) From 23bd2eebf53ddd92eac4a1d4589b773028556246 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 17:10:34 +0200 Subject: [PATCH 07/13] schedule fix --- tests/optimization_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 0eaae16d310..f3147c8998c 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -51,7 +51,7 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5) + m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5) x = np.arange(0, 1000) y = [m.get_lr(xe) for xe in x] plt.plot(y) From 5fed5bb3d687c9eafe04ec5e22f937c5355e53ce Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 17:20:29 +0200 Subject: [PATCH 08/13] schedule fix --- pytorch_pretrained_bert/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 8c8dc3b8624..92cf2b05eb8 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -38,7 +38,7 @@ class LRSchedule(object): :param kw: """ super(LRSchedule, self).__init__(**kw) - if t_total <= 0: + if t_total < 0: logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) From 1b4ce76c3885357bbaa975b30ad32d4e1f47f032 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 17:40:12 +0200 Subject: [PATCH 09/13] schedule fix --- tests/optimization_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/optimization_test.py b/tests/optimization_test.py index f3147c8998c..80216cc8d42 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -22,7 +22,7 @@ import torch from pytorch_pretrained_bert import BertAdam from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule -from matplotlib import pyplot as plt +#from matplotlib import pyplot as plt import numpy as np class OptimizationTest(unittest.TestCase): @@ -54,8 +54,8 @@ class WarmupCosineWithRestartsTest(unittest.TestCase): m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5) x = np.arange(0, 1000) y = [m.get_lr(xe) for xe in x] - plt.plot(y) - plt.show(block=False) + # plt.plot(y) + # plt.show(block=False) y = np.asarray(y) expected_zeros = y[[0, 200, 400, 600, 800]] print(expected_zeros) From 20686b78fc786bf662b4ed1bd743823aeef57fd8 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 18:13:52 +0200 Subject: [PATCH 10/13] schedule fix --- 
pytorch_pretrained_bert/optimization.py | 6 +++--- tests/optimization_test.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 92cf2b05eb8..df5b50b51df 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -42,8 +42,8 @@ class LRSchedule(object): logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) - warmup = max(warmup, 0) - self.warmup, self.t_total = warmup, t_total + warmup = max(warmup, 0.) + self.warmup, self.t_total = float(warmup), float(t_total) self.warned_for_t_total_at_progress = -1 def get_lr(self, step, nowarn=False): @@ -153,7 +153,7 @@ class WarmupLinearSchedule(LRSchedule): def get_lr_(self, progress): if progress < self.warmup: return progress / self.warmup - return max((progress - 1.) / (self.warmup - 1.), 0) + return max((progress - 1.) / (self.warmup - 1.), 0.) SCHEDULES = { diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 80216cc8d42..e74f4bba6ca 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -51,7 +51,7 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5) + m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5) x = np.arange(0, 1000) y = [m.get_lr(xe) for xe in x] # plt.plot(y) From fc7693adc33484942a92ebba63117bf166883c0e Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 18:16:47 +0200 Subject: [PATCH 11/13] schedule fix --- pytorch_pretrained_bert/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index df5b50b51df..ca973015a67 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -54,7 +54,7 @@ class LRSchedule(object): """ if self.t_total < 0: return 1. - progress = step / self.t_total + progress = float(step) / self.t_total ret = self.get_lr_(progress) # warning for exceeding t_total (only active with warmup_linear if not nowarn and self.warn_t_total and progress > 1. 
and progress > self.warned_for_t_total_at_progress: From bb7557d3ab96f139997bfaa70ff2b4a6c18994e0 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Sun, 21 Apr 2019 13:48:33 +0200 Subject: [PATCH 12/13] - removed __all__ in optimization - removed unused plotting code - using ABC for LRSchedule - added some schedule object init tests --- pytorch_pretrained_bert/optimization.py | 30 ++++++++++--------- .../optimization_openai.py | 7 +++-- tests/optimization_test.py | 29 +++++++++++++++--- 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index ca973015a67..d2d4f7f5e5a 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -20,15 +20,12 @@ from torch.optim import Optimizer from torch.optim.optimizer import required from torch.nn.utils import clip_grad_norm_ import logging +from abc import ABC, abstractmethod logger = logging.getLogger(__name__) -__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", - "WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"] - - -class LRSchedule(object): +class _LRSchedule(ABC): """ Parent of all LRSchedules here. """ warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense def __init__(self, warmup=0.002, t_total=-1, **kw): @@ -37,7 +34,7 @@ class LRSchedule(object): :param t_total: how many training steps (updates) are planned :param kw: """ - super(LRSchedule, self).__init__(**kw) + super(_LRSchedule, self).__init__(**kw) if t_total < 0: logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: @@ -65,16 +62,21 @@ class LRSchedule(object): # end warning return ret + @abstractmethod def get_lr_(self, progress): """ :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress :return: learning rate multiplier for current update """ return 1. - # raise NotImplemented("use subclass") - -class WarmupCosineSchedule(LRSchedule): +class ConstantLR(_LRSchedule): + def get_lr_(self, progress): + return 1. + + +class WarmupCosineSchedule(_LRSchedule): """ Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts. """ @@ -135,7 +137,7 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul return ret -class WarmupConstantSchedule(LRSchedule): +class WarmupConstantSchedule(_LRSchedule): """ Applies linear warmup. After warmup always returns 1.. """ @@ -145,7 +147,7 @@ class WarmupConstantSchedule(LRSchedule): return 1. -class WarmupLinearSchedule(LRSchedule): +class WarmupLinearSchedule(_LRSchedule): """ Linear warmup. Linear decay after warmup. 
""" @@ -157,8 +159,8 @@ class WarmupLinearSchedule(LRSchedule): SCHEDULES = { - None: LRSchedule, - "none": LRSchedule, + None: ConstantLR, + "none": ConstantLR, "warmup_cosine": WarmupCosineSchedule, "warmup_constant": WarmupConstantSchedule, "warmup_linear": WarmupLinearSchedule @@ -185,7 +187,7 @@ class BertAdam(Optimizer): b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) - if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES: + if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) if not 0.0 <= b1 < 1.0: raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) @@ -194,7 +196,7 @@ class BertAdam(Optimizer): if not e >= 0.0: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) # initialize schedule object - if not isinstance(schedule, LRSchedule): + if not isinstance(schedule, _LRSchedule): schedule_type = SCHEDULES[schedule] schedule = schedule_type(warmup=warmup, t_total=t_total) else: diff --git a/pytorch_pretrained_bert/optimization_openai.py b/pytorch_pretrained_bert/optimization_openai.py index 5bfea476a68..0cf0494e206 100644 --- a/pytorch_pretrained_bert/optimization_openai.py +++ b/pytorch_pretrained_bert/optimization_openai.py @@ -20,7 +20,8 @@ from torch.optim import Optimizer from torch.optim.optimizer import required from torch.nn.utils import clip_grad_norm_ import logging -from .optimization import * +from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ + WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ class OpenAIAdam(Optimizer): vector_l2=False, max_grad_norm=-1, **kwargs): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) - if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES: + if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) if not 0.0 <= b1 < 1.0: raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) @@ -42,7 +43,7 @@ class OpenAIAdam(Optimizer): if not e >= 0.0: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) # initialize schedule object - if not isinstance(schedule, LRSchedule): + if not isinstance(schedule, _LRSchedule): schedule_type = SCHEDULES[schedule] schedule = schedule_type(warmup=warmup, t_total=t_total) else: diff --git a/tests/optimization_test.py b/tests/optimization_test.py index e74f4bba6ca..f52aeb506b3 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -21,10 +21,11 @@ import unittest import torch from pytorch_pretrained_bert import BertAdam -from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule -#from matplotlib import pyplot as plt +from pytorch_pretrained_bert import OpenAIAdam +from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupCosineWithWarmupRestartsSchedule import numpy as np + class OptimizationTest(unittest.TestCase): def assertListAlmostEqual(self, list1, list2, tol): @@ -49,13 +50,33 @@ class OptimizationTest(unittest.TestCase): self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 
+class ScheduleInitTest(unittest.TestCase): + def test_bert_sched_init(self): + m = torch.nn.Linear(50, 50) + optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none") + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule)) + # shouldn't fail + + def test_openai_sched_init(self): + m = torch.nn.Linear(50, 50) + optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none") + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule)) + # shouldn't fail + + class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5) x = np.arange(0, 1000) y = [m.get_lr(xe) for xe in x] - # plt.plot(y) - # plt.show(block=False) y = np.asarray(y) expected_zeros = y[[0, 200, 400, 600, 800]] print(expected_zeros) From 69850b40114095aaa093adaf4ef2181cfe4176ed Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Sun, 21 Apr 2019 14:02:38 +0200 Subject: [PATCH 13/13] python 2 compat --- pytorch_pretrained_bert/optimization.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index d2d4f7f5e5a..7e88b1b61c6 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -20,11 +20,18 @@ from torch.optim import Optimizer from torch.optim.optimizer import required from torch.nn.utils import clip_grad_norm_ import logging -from abc import ABC, abstractmethod +import abc +import sys logger = logging.getLogger(__name__) +if sys.version_info >= (3, 4): + ABC = abc.ABC +else: + ABC = abc.ABCMeta('ABC', (), {}) + + class _LRSchedule(ABC): """ Parent of all LRSchedules here. """ warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense @@ -62,7 +69,7 @@ class _LRSchedule(ABC): # end warning return ret - @abstractmethod + @abc.abstractmethod def get_lr_(self, progress): """ :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress