Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
- removed __all__ in optimization
- removed unused plotting code
- using ABC for LRSchedule
- added some schedule object init tests
This commit is contained in:
parent 34ccc8ebf4
commit bb7557d3ab
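For context, a minimal sketch (not part of this commit) of how the reworked schedule API is meant to be used after these changes. It assumes the pytorch_pretrained_bert package as touched by the diff below; the custom InverseSqrtSchedule class is purely hypothetical, and it assumes the base class stores its constructor arguments as self.warmup and self.t_total.

import torch
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert.optimization import _LRSchedule

class InverseSqrtSchedule(_LRSchedule):
    # hypothetical custom schedule: linear warmup, then inverse-square-root decay
    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup          # warmup region: ramp from 0 to 1
        return (self.warmup / progress) ** 0.5     # decay region

model = torch.nn.Linear(50, 50)

# pass a ready-made schedule object directly ...
optim = BertAdam(model.parameters(), lr=1e-3,
                 schedule=InverseSqrtSchedule(warmup=0.1, t_total=1000))

# ... or keep the string API, which now resolves to a schedule class via SCHEDULES
# (schedule=None falls back to ConstantLR, see the diff below)
optim = BertAdam(model.parameters(), lr=1e-3, warmup=0.1, t_total=1000,
                 schedule="warmup_linear")
print(type(optim.param_groups[0]["schedule"]).__name__)   # WarmupLinearSchedule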
@@ -20,15 +20,12 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
+from abc import ABC, abstractmethod

 logger = logging.getLogger(__name__)


-__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
-           "WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"]
-
-
-class LRSchedule(object):
+class _LRSchedule(ABC):
     """ Parent of all LRSchedules here. """
     warn_t_total = False        # is set to True for schedules where progressing beyond t_total steps doesn't make sense
     def __init__(self, warmup=0.002, t_total=-1, **kw):
@@ -37,7 +34,7 @@ class LRSchedule(object):
         :param t_total: how many training steps (updates) are planned
         :param kw:
         """
-        super(LRSchedule, self).__init__(**kw)
+        super(_LRSchedule, self).__init__(**kw)
         if t_total < 0:
             logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
         if not 0.0 <= warmup < 1.0 and not warmup == -1:
@@ -65,16 +62,21 @@ class LRSchedule(object):
         # end warning
         return ret

+    @abstractmethod
     def get_lr_(self, progress):
         """
         :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
         :return: learning rate multiplier for current update
         """
         return 1.
-        # raise NotImplemented("use subclass")


-class WarmupCosineSchedule(LRSchedule):
+class ConstantLR(_LRSchedule):
+    def get_lr_(self, progress):
+        return 1.
+
+
+class WarmupCosineSchedule(_LRSchedule):
     """
     Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
     """
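The practical effect of the @abstractmethod hunk above, sketched as a standalone snippet (simplified stand-ins for the real classes, not code from this commit): the base schedule class can no longer be instantiated, which is why the concrete ConstantLR is introduced and, further down, replaces LRSchedule as the default for schedule=None / "none" in SCHEDULES.

from abc import ABC, abstractmethod

class _LRSchedule(ABC):                     # simplified stand-in for the real base class
    def __init__(self, warmup=0.002, t_total=-1, **kw):
        super().__init__(**kw)
        self.warmup, self.t_total = warmup, t_total

    @abstractmethod
    def get_lr_(self, progress):
        return 1.

class ConstantLR(_LRSchedule):
    def get_lr_(self, progress):
        return 1.

try:
    _LRSchedule()                           # abstract method present -> TypeError
except TypeError as err:
    print("cannot instantiate base class:", err)

print(ConstantLR().get_lr_(0.5))            # 1.0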
@@ -135,7 +137,7 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
         return ret


-class WarmupConstantSchedule(LRSchedule):
+class WarmupConstantSchedule(_LRSchedule):
     """
     Applies linear warmup. After warmup always returns 1..
     """
@@ -145,7 +147,7 @@ class WarmupConstantSchedule(LRSchedule):
         return 1.


-class WarmupLinearSchedule(LRSchedule):
+class WarmupLinearSchedule(_LRSchedule):
     """
     Linear warmup. Linear decay after warmup.
     """
@@ -157,8 +159,8 @@ class WarmupLinearSchedule(LRSchedule):


 SCHEDULES = {
-    None: LRSchedule,
-    "none": LRSchedule,
+    None: ConstantLR,
+    "none": ConstantLR,
     "warmup_cosine": WarmupCosineSchedule,
     "warmup_constant": WarmupConstantSchedule,
     "warmup_linear": WarmupLinearSchedule
@@ -185,7 +187,7 @@ class BertAdam(Optimizer):
                  b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
@@ -194,7 +196,7 @@ class BertAdam(Optimizer):
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         # initialize schedule object
-        if not isinstance(schedule, LRSchedule):
+        if not isinstance(schedule, _LRSchedule):
             schedule_type = SCHEDULES[schedule]
             schedule = schedule_type(warmup=warmup, t_total=t_total)
         else:
@@ -20,7 +20,8 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
-from .optimization import *
+from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
+    WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule

 logger = logging.getLogger(__name__)

@@ -33,7 +34,7 @@ class OpenAIAdam(Optimizer):
                  vector_l2=False, max_grad_norm=-1, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
@@ -42,7 +43,7 @@ class OpenAIAdam(Optimizer):
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         # initialize schedule object
-        if not isinstance(schedule, LRSchedule):
+        if not isinstance(schedule, _LRSchedule):
             schedule_type = SCHEDULES[schedule]
             schedule = schedule_type(warmup=warmup, t_total=t_total)
         else:
@@ -21,10 +21,11 @@ import unittest
 import torch

 from pytorch_pretrained_bert import BertAdam
-from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule
-#from matplotlib import pyplot as plt
+from pytorch_pretrained_bert import OpenAIAdam
+from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupCosineWithWarmupRestartsSchedule
 import numpy as np
+

 class OptimizationTest(unittest.TestCase):

     def assertListAlmostEqual(self, list1, list2, tol):
@@ -49,13 +50,33 @@ class OptimizationTest(unittest.TestCase):
         self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)


+class ScheduleInitTest(unittest.TestCase):
+    def test_bert_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
+    def test_openai_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
+
 class WarmupCosineWithRestartsTest(unittest.TestCase):
     def test_it(self):
         m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
         x = np.arange(0, 1000)
         y = [m.get_lr(xe) for xe in x]
-        # plt.plot(y)
-        # plt.show(block=False)
         y = np.asarray(y)
         expected_zeros = y[[0, 200, 400, 600, 800]]
         print(expected_zeros)