mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
optimization schedule test update
This commit is contained in:
parent
d164867d90
commit
b64cc63a77
@ -51,11 +51,18 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
|
|
||||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||||
def test_it(self):
|
def test_it(self):
|
||||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=3)
|
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5)
|
||||||
x = np.arange(0, 1000) / 1000
|
x = np.arange(0, 1000) / 1000
|
||||||
y = [m.get_lr_(xe) for xe in x]
|
y = [m.get_lr_(xe) for xe in x]
|
||||||
plt.plot(y)
|
plt.plot(y)
|
||||||
plt.show()
|
plt.show(block=False)
|
||||||
|
y = np.asarray(y)
|
||||||
|
expected_zeros = y[[0, 200, 400, 600, 800]]
|
||||||
|
print(expected_zeros)
|
||||||
|
expected_ones = y[[50, 250, 450, 650, 850]]
|
||||||
|
print(expected_ones)
|
||||||
|
self.assertTrue(np.allclose(expected_ones, 1))
|
||||||
|
self.assertTrue(np.allclose(expected_zeros, 0))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user