mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)

Same fix for addcmul_
parent ad02c961c6
commit 55bda52555
@@ -152,8 +152,8 @@ class AdamW(Optimizer):
                 # Decay the first and second moment running average coefficient
                 # In-place operations to update the averages at the same time
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 denom = exp_avg_sq.sqrt().add_(group["eps"])

                 step_size = group["lr"]
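
For context: recent PyTorch releases deprecated the in-place overloads that take the scalar coefficient as the first positional argument, e.g. add_(scalar, tensor) and addcmul_(scalar, tensor1, tensor2); the supported forms pass the tensors first and the coefficient as the keyword argument alpha= or value=. Below is a minimal sketch (not part of the commit; the toy values are assumptions) checking that the patched addcmul_ call still computes the usual second-moment update:

    import torch

    # Sketch (not from the commit): verify that the fused in-place update
    #   exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad * grad
    # matches the out-of-place expression, using the non-deprecated signature.
    torch.manual_seed(0)
    beta2 = 0.999                      # hypothetical toy value
    grad = torch.randn(4)
    exp_avg_sq = torch.rand(4)

    fused = exp_avg_sq.clone()
    fused.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    reference = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad
    assert torch.allclose(fused, reference)

The chained mul_/addcmul_ form updates exp_avg_sq in place in a single fused kernel, avoiding a temporary for grad * grad, which is presumably why the optimizer is written this way rather than with the out-of-place expression.
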
@@ -173,6 +173,6 @@ class AdamW(Optimizer):
                 # of the weights to the loss with plain (non-momentum) SGD.
                 # Add weight decay at the end (fixed version)
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p.data, -group["lr"] * group["weight_decay"])
+                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

         return loss
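
The second hunk touches the decoupled weight-decay step that gives AdamW its name: the decay is applied directly to the weights rather than being folded into the gradient. A small sketch (not from the commit; lr and weight_decay are hypothetical toy values) showing that the alpha= form scales the weights by 1 - lr * weight_decay:

    import torch

    # Sketch (not from the commit): the decoupled weight-decay line,
    #   p <- p - lr * weight_decay * p,
    # written with the keyword-only alpha argument.
    lr, weight_decay = 1e-3, 0.01      # hypothetical toy values
    p = torch.randn(4)

    decayed = p.clone()
    decayed.add_(p, alpha=-lr * weight_decay)

    assert torch.allclose(decayed, p * (1.0 - lr * weight_decay))
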