mirror of https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
fixing optimization
parent 852e4b3c00
commit 088ad45888
@@ -4,16 +4,19 @@ from torch.optim import Optimizer
 from torch.nn.utils import clip_grad_norm_

 def warmup_cosine(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x)))
+    if x < warmup:
+        return x/warmup
+    return 0.5 * (1.0 + torch.cos(math.pi * x))

 def warmup_constant(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*1
+    if x < warmup:
+        return x/warmup
+    return 1.0

 def warmup_linear(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return (s*(x/warmup) + (1-s))*(1-x)
+    if x < warmup:
+        return x/warmup
+    return 1.0 - x

 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
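To make the reworked schedule shapes concrete, here is a small standalone sketch (not part of the commit) that evaluates warmup_linear, copied from the hunk above, at a few values of training progress, assuming warmup=0.1:

# Standalone sketch of the new warmup_linear from the diff above.
def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x/warmup
    return 1.0 - x

# With warmup=0.1 the multiplier ramps from 0 toward 1 over the first 10% of
# training progress, then decays linearly back to 0 at x = 1.0.
for x in (0.0, 0.05, 0.1, 0.5, 1.0):
    print(x, warmup_linear(x, warmup=0.1))   # 0.0, 0.5, 0.9, 0.5, 0.0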
@@ -24,24 +27,34 @@ SCHEDULES = {
 class BERTAdam(Optimizer):
     """Implements Open AI version of Adam algorithm with weight decay fix.
+    Params:
+        lr,
+        warmup=-1,
+        t_total=-1,
+        schedule='warmup_linear',
+        b1=0.9,
+        b2=0.999,
+        e=1e-6,
+        weight_decay_rate=0.01,
+        max_grad_norm=1.0
     """
-    def __init__(self, params, lr, schedule, warmup, t_total,
-                 b1=0.9, b2=0.999, e=1e-6, l2=0,
-                 vector_l2=False, max_grad_norm=-1, **kwargs):
-        if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+    def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
+                 b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
+                 max_grad_norm=1.0):
+        if not lr >= 0.0:
+            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
         if schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
-        if not 0 <= warmup:
-            raise ValueError("Invalid warmup: {}".format(warmup))
+        if not 0.0 <= warmup < 1.0 and not warmup == -1:
+            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
-            raise ValueError("Invalid b1 parameter: {}".format(b1))
+            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         if not 0.0 <= b2 < 1.0:
-            raise ValueError("Invalid b2 parameter: {}".format(b2))
-        if not 0.0 <= e:
-            raise ValueError("Invalid epsilon value: {}".format(e))
+            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
+        if not e >= 0.0:
+            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
-                        b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2,
+                        b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
                         max_grad_norm=max_grad_norm)
         super(BERTAdam, self).__init__(params, defaults)
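As a usage illustration of the new keyword signature (a sketch under assumptions, not code from this commit; model and num_train_steps are placeholder names):

# Hypothetical usage of the new signature; `model` and `num_train_steps`
# are placeholders, not defined in this commit.
optimizer = BERTAdam(model.parameters(),
                     lr=3e-5,
                     warmup=0.1,               # ramp the LR over the first 10% of steps
                     t_total=num_train_steps,  # -1 would disable scheduling entirely
                     schedule='warmup_linear',
                     weight_decay_rate=0.01,
                     max_grad_norm=1.0)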
@@ -52,8 +65,11 @@ class BERTAdam(Optimizer):
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
-                schedule_fct = SCHEDULES[group['schedule']]
-                lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                if group['t_total'] != -1:
+                    schedule_fct = SCHEDULES[group['schedule']]
+                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                else:
+                    lr_scheduled = group['lr']
                 lr.append(lr_scheduled)
         return lr
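As a worked example (not in the diff), with lr=1e-4, warmup=0.1, t_total=1000 and the warmup_linear schedule, the rate get_lr reports halfway through training would be:

# Worked sketch: scheduled LR at step 500 of 1000 with warmup_linear.
progress = 500 / 1000                             # state['step'] / t_total
lr_scheduled = 1e-4 * warmup_linear(progress, warmup=0.1)
# warmup_linear(0.5, 0.1) = 1.0 - 0.5 = 0.5, so lr_scheduled = 5e-5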
@@ -103,32 +119,22 @@ class BERTAdam(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['next_m'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['next_v'] = torch.zeros_like(p.data)

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                next_m, next_v = state['next_m'], state['next_v']
                 beta1, beta2 = group['b1'], group['b2']

-                state['step'] += 1
-
                 # Add grad clipping
                 if group['max_grad_norm'] > 0:
                     clip_grad_norm_(p, group['max_grad_norm'])

                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                denom = exp_avg_sq.sqrt().add_(group['e'])
-
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-
-                schedule_fct = SCHEDULES[group['schedule']]
-                lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
-                step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
-
-                p.data.addcdiv_(-step_size, exp_avg, denom)
+                # In-place operations to update the averages at the same time
+                next_m.mul_(beta1).add_(1 - beta1, grad)
+                next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                update = next_m / (next_v.sqrt() + group['e'])

                 # Just adding the square of the weights to the loss function is *not*
                 # the correct way of using L2 regularization/weight decay with Adam,
@@ -137,7 +143,22 @@ class BERTAdam(Optimizer):
                 # Instead we want ot decay the weights in a manner that doesn't interact
                 # with the m/v parameters. This is equivalent to adding the square
                 # of the weights to the loss with plain (non-momentum) SGD.
-                if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0:
-                    p.data.add_(-lr_scheduled * group['l2'], p.data)
+                if group['weight_decay_rate'] > 0.0:
+                    update += group['weight_decay_rate'] * p.data
+
+                if group['t_total'] != -1:
+                    schedule_fct = SCHEDULES[group['schedule']]
+                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                else:
+                    lr_scheduled = group['lr']
+
+                update_with_lr = lr_scheduled * update
+                p.data.add_(-update_with_lr)
+
+                state['step'] += 1
+
+                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
+                # bias_correction1 = 1 - beta1 ** state['step']
+                # bias_correction2 = 1 - beta2 ** state['step']

         return loss
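Taken together, the two step() hunks amount to the following single-tensor sketch (an illustration, not the committed code; it uses current PyTorch keyword arguments instead of the positional form in the diff, and the function and argument names are placeholders): the decoupled weight-decay term is added to the Adam update before the scheduled learning rate is applied, and no bias correction is used.

import torch

def bert_adam_single_step(p, grad, m, v, lr_scheduled,
                          b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01):
    # Exponential moving averages of the gradient and its square.
    m.mul_(b1).add_(grad, alpha=1 - b1)
    v.mul_(b2).addcmul_(grad, grad, value=1 - b2)
    update = m / (v.sqrt() + e)
    # Decoupled weight decay: decay the weights directly instead of adding an
    # L2 term to the loss, so the decay does not interact with m and v.
    if weight_decay_rate > 0.0:
        update += weight_decay_rate * p
    p.add_(-lr_scheduled * update)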
@@ -31,13 +31,18 @@ class OptimizationTest(unittest.TestCase):

     def test_adam(self):
         w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
-        x = torch.tensor([0.4, 0.2, -0.5])
+        target = torch.tensor([0.4, 0.2, -0.5])
         criterion = torch.nn.MSELoss(reduction='elementwise_mean')
-        optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
+        # No warmup, constant schedule, no gradient clipping
+        optimizer = optimization.BERTAdam(params=[w], lr=2e-1,
+                                          weight_decay_rate=0.0,
+                                          max_grad_norm=-1)
         for _ in range(100):
-            loss = criterion(w, x)
+            loss = criterion(w, target)
             loss.backward()
             optimizer.step()
             w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves.
             w.grad.zero_()
         self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
@@ -483,10 +483,14 @@ def main():
     model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)

-    optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
-                          {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
-                          ],
-                         lr=args.learning_rate, schedule='warmup_linear',
+    no_decay = ['bias', 'gamma', 'beta']
+    optimizer_parameters = [
+         {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
+         {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
+         ]
+
+    optimizer = BERTAdam(optimizer_parameters,
+                         lr=args.learning_rate,
                          warmup=args.warmup_proportion,
                          t_total=num_train_steps)
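For reference, a minimal sketch (with made-up parameters, not from the commit) of how these groups are consumed: each dict becomes an optimizer param group, and step() reads group['weight_decay_rate'], so only the first group is decayed; 'gamma' and 'beta' presumably name LayerNorm-style parameters here.

import torch

# Hypothetical stand-ins for real model parameters.
weight = torch.nn.Parameter(torch.randn(768, 768))
bias = torch.nn.Parameter(torch.zeros(768))

optimizer = BERTAdam([
    {'params': [weight], 'weight_decay_rate': 0.01},  # decayed
    {'params': [bias], 'weight_decay_rate': 0.0},     # not decayed
], lr=1e-4, warmup=0.1, t_total=1000)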
@@ -38,10 +38,16 @@ class OptimizationTest(tf.test.TestCase):
       init_op = tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer())
       sess.run(init_op)
-      for _ in range(100):
+      np_w = sess.run(w)
+      np_loss = sess.run(loss)
+      np_grad = sess.run(grads)[0]
+      for i in range(100):
+        print(i)
         sess.run(train_op)
-      w_np = sess.run(w)
-      self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
+      np_w = sess.run(w)
+      np_loss = sess.run(loss)
+      np_grad = sess.run(grads)[0]
+      self.assertAllClose(np_w.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)


 if __name__ == "__main__":