mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
added warning
This commit is contained in:
parent
9bc3773c84
commit
46ef646016
@ -35,17 +35,6 @@ def warmup_constant(x, warmup=0.002):
|
||||
return x/warmup
|
||||
return 1.0
|
||||
|
||||
class Warmup_Linear_with_Warning(object):
|
||||
def __init__(self, **kw):
|
||||
super(Warmup_Linear_with_Warning, self).__init__()
|
||||
self.warned_at_x = -1
|
||||
|
||||
def __call__(self, x, warmup=0.002):
|
||||
if x > 1 and x > self.warned_at_x:
|
||||
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
|
||||
self.warned_at_x = x
|
||||
return warmup_linear(x, warmup=warmup)
|
||||
|
||||
def warmup_linear(x, warmup=0.002):
|
||||
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
||||
After `t_total`-th training step, learning rate is zero. """
|
||||
@ -54,9 +43,9 @@ def warmup_linear(x, warmup=0.002):
|
||||
return max((x-1.)/(warmup-1.), 0)
|
||||
|
||||
SCHEDULES = {
|
||||
'warmup_cosine':warmup_cosine,
|
||||
'warmup_constant':warmup_constant,
|
||||
'warmup_linear': Warmup_Linear_with_Warning(), #warmup_linear,
|
||||
'warmup_cosine': warmup_cosine,
|
||||
'warmup_constant': warmup_constant,
|
||||
'warmup_linear': warmup_linear,
|
||||
}
|
||||
|
||||
|
||||
@ -93,6 +82,8 @@ class BertAdam(Optimizer):
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(BertAdam, self).__init__(params, defaults)
|
||||
# warning for t_total exceeded
|
||||
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf")
|
||||
|
||||
def get_lr(self):
|
||||
lr = []
|
||||
@ -163,7 +154,15 @@ class BertAdam(Optimizer):
|
||||
|
||||
if group['t_total'] != -1:
|
||||
schedule_fct = SCHEDULES[group['schedule']]
|
||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
||||
# warning for exceeding t_total (only active with warmup_linear
|
||||
progress = state['step']/group['t_total']
|
||||
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
||||
logger.warning(
|
||||
"Training beyond specified 't_total' steps. Learning rate set to zero. "
|
||||
"Please set 't_total' of {} correctly.".format(self.__class__.__name__))
|
||||
self._warned_for_t_total_at_progress = progress
|
||||
# end warning
|
||||
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||
else:
|
||||
lr_scheduled = group['lr']
|
||||
|
||||
|
@ -40,8 +40,6 @@ def warmup_linear(x, warmup=0.002):
|
||||
After `t_total`-th training step, learning rate is zero. """
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
if x > 1:
|
||||
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
|
||||
return max((x-1.)/(warmup-1.), 0)
|
||||
|
||||
SCHEDULES = {
|
||||
@ -73,6 +71,8 @@ class OpenAIAdam(Optimizer):
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(OpenAIAdam, self).__init__(params, defaults)
|
||||
# warning for t_total exceeded
|
||||
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf")
|
||||
|
||||
def get_lr(self):
|
||||
lr = []
|
||||
@ -137,7 +137,15 @@ class OpenAIAdam(Optimizer):
|
||||
|
||||
if group['t_total'] != -1:
|
||||
schedule_fct = SCHEDULES[group['schedule']]
|
||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
||||
# warning for exceeding t_total (only active with warmup_linear
|
||||
progress = state['step']/group['t_total']
|
||||
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
||||
logger.warning(
|
||||
"Training beyond specified 't_total' steps. Learning rate set to zero. "
|
||||
"Please set 't_total' of {} correctly.".format(self.__class__.__name__))
|
||||
self._warned_for_t_total_at_progress = progress
|
||||
# end warning
|
||||
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||
else:
|
||||
lr_scheduled = group['lr']
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user