fix pt-1.9.0 add_ deprecation (#12217)

* fix pt-1.9.0 add_ deprecation

* add () for clarity

* Trigger CI

* require_version("torch>=1.5.0")
Stas Bekman, 2021-06-17 08:53:59 -07:00 (committed by GitHub)
parent 3a960c4857
commit d6ea91c96a
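Background: PyTorch 1.9.0 deprecates the positional-scalar overload Tensor.add_(Number, Tensor); the supported spelling passes the tensor first and the scalar through the alpha keyword, so t.add_(other, alpha=n) computes t += n * other. A minimal standalone sketch of the two spellings (not part of this diff):

    import torch

    a = torch.ones(3)
    b = torch.full((3,), 2.0)

    # a.add_(0.5, b)      # deprecated order (scalar first): warns on pt-1.9.0
    a.add_(b, alpha=0.5)  # supported order: a <- a + 0.5 * b -> tensor([2., 2., 2.])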


@@ -24,6 +24,7 @@ from torch.optim.lr_scheduler import LambdaLR
 from .trainer_utils import SchedulerType
 from .utils import logging
+from .utils.versions import require_version


 logger = logging.get_logger(__name__)
@@ -296,6 +297,7 @@ class AdamW(Optimizer):
         weight_decay: float = 0.0,
         correct_bias: bool = True,
     ):
+        require_version("torch>=1.5.0")  # add_ with alpha
         if lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
         if not 0.0 <= betas[0] < 1.0:
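Both optimizers now fail fast at construction time: require_version checks the installed package against a pip-style specifier and raises if it is not satisfied. A minimal sketch, assuming the optional second argument is a hint appended to the error message:

    from transformers.utils.versions import require_version

    # Raises (with the hint in the message) when the installed torch is
    # older than 1.5.0, i.e. before add_ accepted the alpha keyword.
    require_version("torch>=1.5.0", "try: pip install -U torch")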
@@ -343,7 +345,7 @@ class AdamW(Optimizer):
                 # Decay the first and second moment running average coefficient
                 # In-place operations to update the averages at the same time
-                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 denom = exp_avg_sq.sqrt().add_(group["eps"])
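These two in-place lines are the Adam moment updates m <- beta1*m + (1-beta1)*g and v <- beta2*v + (1-beta2)*g*g; alpha scales the added tensor, value scales the elementwise product in addcmul_. A standalone sketch checking the fused form against the closed form (names are illustrative only):

    import torch

    beta1, beta2 = 0.9, 0.999
    g = torch.randn(4)
    m, v = torch.zeros(4), torch.zeros(4)

    m.mul_(beta1).add_(g, alpha=(1.0 - beta1))         # m <- beta1*m + (1-beta1)*g
    v.mul_(beta2).addcmul_(g, g, value=(1.0 - beta2))  # v <- beta2*v + (1-beta2)*g*g

    # starting from zero state, one step reduces to the scaled gradient terms
    assert torch.allclose(m, (1.0 - beta1) * g)
    assert torch.allclose(v, (1.0 - beta2) * g * g)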
@@ -364,7 +366,7 @@ class AdamW(Optimizer):
                 # of the weights to the loss with plain (non-momentum) SGD.
                 # Add weight decay at the end (fixed version)
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
+                    p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))

         return loss
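With alpha=-lr*wd the add is p <- p - lr*wd*p = (1 - lr*wd)*p: AdamW's decoupled weight decay, applied to the weights directly rather than folded into the gradient. A tiny standalone check (values are arbitrary):

    import torch

    lr, wd = 0.1, 0.01
    p = torch.ones(3)
    p.add_(p, alpha=(-lr * wd))  # p <- (1 - lr*wd) * p
    assert torch.allclose(p, torch.full((3,), 1.0 - lr * wd))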
@@ -458,6 +460,7 @@ class Adafactor(Optimizer):
         relative_step=True,
         warmup_init=False,
     ):
+        require_version("torch>=1.5.0")  # add_ with alpha
         if lr is not None and relative_step:
             raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
         if warmup_init and not relative_step:
@@ -566,8 +569,8 @@ class Adafactor(Optimizer):
                     exp_avg_sq_row = state["exp_avg_sq_row"]
                     exp_avg_sq_col = state["exp_avg_sq_col"]

-                    exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
-                    exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
+                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                     # Approximation of exponential moving average of square of gradient
                     update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
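For a d_row x d_col weight, Adafactor keeps only row and column EMAs of the squared update (O(d_row + d_col) state instead of O(d_row * d_col)) and rebuilds the full second-moment estimate as an outer-product approximation. A rough standalone sketch of that factorization, not the exact _approx_sq_grad from this file:

    import torch

    update_sq = torch.rand(8, 16)  # stands in for the squared-gradient term
    row = update_sq.mean(dim=-1)   # row statistics, shape (8,)
    col = update_sq.mean(dim=-2)   # column statistics, shape (16,)

    # rank-1 reconstruction: v_hat[i, j] = row[i] * col[j] / mean(row),
    # which equals row_sum[i] * col_sum[j] / total_sum
    v_hat = torch.outer(row, col) / row.mean()
    assert v_hat.shape == update_sq.shape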
@@ -575,7 +578,7 @@ class Adafactor(Optimizer):
                 else:
                     exp_avg_sq = state["exp_avg_sq"]

-                    exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
+                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                     update = exp_avg_sq.rsqrt().mul_(grad)

                 update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
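The div_ line is Adafactor's update clipping: when the root-mean-square of the update exceeds clip_threshold, the update is rescaled so its RMS equals the threshold; below it, the clamp(min=1.0) makes the division a no-op. A hedged sketch of that rule, assuming _rms is the L2 norm over the square root of the element count:

    import torch

    def rms(t):
        return t.norm(2) / (t.numel() ** 0.5)  # root-mean-square of a tensor

    clip_threshold = 1.0
    update = torch.randn(5) * 10.0
    update.div_((rms(update) / clip_threshold).clamp(min=1.0))
    assert rms(update) <= clip_threshold + 1e-6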
@@ -583,11 +586,11 @@ class Adafactor(Optimizer):
                 if use_first_moment:
                     exp_avg = state["exp_avg"]
-                    exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
+                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                     update = exp_avg

                 if group["weight_decay"] != 0:
-                    p_data_fp32.add_(-group["weight_decay"] * lr, p_data_fp32)
+                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                 p_data_fp32.add_(-update)
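The old argument order still runs on pt-1.9.0 but emits a deprecation warning on every call; the rewritten calls are warning-free and numerically identical. A standalone way to surface the warning (newer releases may drop the deprecated overload entirely):

    import warnings

    import torch

    t = torch.zeros(2)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        t.add_(2.0, torch.ones(2))  # deprecated order: scalar first
    print([str(w.message) for w in caught])  # expect a deprecation notice on pt-1.9.x

    t.add_(torch.ones(2), alpha=2.0)  # fixed order: no warning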