From b64cc63a772dc981040ad8747efbf319ebb4945a Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 16:42:40 +0200 Subject: [PATCH] optimization schedule test update --- tests/optimization_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 8c28ad38adf..218da7581f8 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -51,11 +51,18 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=3) + m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5) x = np.arange(0, 1000) / 1000 y = [m.get_lr_(xe) for xe in x] plt.plot(y) - plt.show() + plt.show(block=False) + y = np.asarray(y) + expected_zeros = y[[0, 200, 400, 600, 800]] + print(expected_zeros) + expected_ones = y[[50, 250, 450, 650, 850]] + print(expected_ones) + self.assertTrue(np.allclose(expected_ones, 1)) + self.assertTrue(np.allclose(expected_zeros, 0))