Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Just import torch AdamW instead (#36177)
* Just import torch AdamW instead
* Update docs too
* Make AdamW undocumented
* make fixup
* Add a basic wrapper class
* Add it back to the docs
* Just remove AdamW entirely
* Remove some AdamW references
* Drop AdamW from the public init
* make fix-copies
* Cleanup some references
* make fixup
* Delete lots of transformers.AdamW references
* Remove extra references to adamw_hf
This commit is contained in:
parent 51bd0ceb9e
commit 9be4728af8
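The change pattern is the same throughout the diff: drop the `AdamW` import from `transformers` and construct the optimizer from `torch.optim` instead. A minimal sketch of that migration (the model and hyperparameter values are illustrative, not taken from any one file in this commit):

```python
import torch
from torch import nn

# Illustrative model; the pattern applies to any nn.Module.
model = nn.Linear(10, 2)

# Before this commit (deprecated, now removed):
#   from transformers import AdamW
#   optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)

# After this commit, the PyTorch implementation is used directly:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
```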
@@ -22,9 +22,6 @@ The `.optimization` module provides:
 - several schedules in the form of schedule objects that inherit from `_LRSchedule`:
 - a gradient accumulation class to accumulate the gradients of multiple batches
 
-## AdamW (PyTorch)
-
-[[autodoc]] AdamW
 
 ## AdaFactor (PyTorch)
@@ -22,10 +22,6 @@ rendered properly in your Markdown viewer.
 - `_LRSchedule` から継承するスケジュール オブジェクトの形式のいくつかのスケジュール:
 - 複数のバッチの勾配を累積するための勾配累積クラス
 
-## AdamW (PyTorch)
-
-[[autodoc]] AdamW
-
 ## AdaFactor (PyTorch)
 
 [[autodoc]] Adafactor
@@ -22,10 +22,6 @@ rendered properly in your Markdown viewer.
 - 继承自 `_LRSchedule` 多个调度器:
 - 一个梯度累积类,用于累积多个批次的梯度
 
-## AdamW (PyTorch)
-
-[[autodoc]] AdamW
-
 ## AdaFactor (PyTorch)
 
 [[autodoc]] Adafactor
@@ -8,7 +8,6 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info
 
 from transformers import (
-    AdamW,
     AutoConfig,
     AutoModel,
     AutoModelForPreTraining,
@@ -20,6 +19,7 @@ from transformers import (
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedTokenizer,
+    is_torch_available,
 )
 from transformers.optimization import (
     Adafactor,
@@ -31,6 +31,10 @@ from transformers.optimization import (
 from transformers.utils.versions import require_version
 
 
+if is_torch_available():
+    import torch
+
+
 logger = logging.getLogger(__name__)
 
 require_version("pytorch_lightning>=1.0.4")
@@ -146,7 +150,7 @@ class BaseTransformer(pl.LightningModule):
             )
 
         else:
-            optimizer = AdamW(
+            optimizer = torch.optim.AdamW(
                 optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
             )
         self.opt = optimizer
@@ -32,7 +32,6 @@ import transformers
 from transformers import (
     MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     WEIGHTS_NAME,
-    AdamW,
     AutoConfig,
     AutoModelForQuestionAnswering,
     AutoTokenizer,
@@ -96,7 +95,7 @@ def train(args, train_dataset, model, tokenizer):
         },
         {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
     ]
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
@@ -43,7 +43,6 @@ from tqdm import tqdm, trange
 from transformers import (
     CONFIG_NAME,
     WEIGHTS_NAME,
-    AdamW,
     OpenAIGPTDoubleHeadsModel,
     OpenAIGPTTokenizer,
     get_linear_schedule_with_warmup,
@@ -236,7 +235,7 @@ def main():
         },
         {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
     ]
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(
         optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
     )
@@ -34,7 +34,6 @@ from tqdm import tqdm, trange
 import transformers
 from transformers import (
     WEIGHTS_NAME,
-    AdamW,
     AutoConfig,
     AutoModelForMultipleChoice,
     AutoTokenizer,
@@ -298,7 +297,7 @@ def train(args, train_dataset, model, tokenizer):
         },
         {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
     ]
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(
         optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
     )
@@ -22,7 +22,6 @@ from transformers import PreTrainedModel, Trainer, logging
 from transformers.models.fsmt.configuration_fsmt import FSMTConfig
 from transformers.optimization import (
     Adafactor,
-    AdamW,
     get_constant_schedule,
     get_constant_schedule_with_warmup,
     get_cosine_schedule_with_warmup,
@@ -102,12 +101,11 @@ class Seq2SeqTrainer(Trainer):
                     "weight_decay": 0.0,
                 },
             ]
-            optimizer_cls = Adafactor if self.args.adafactor else AdamW
             if self.args.adafactor:
                 optimizer_cls = Adafactor
                 optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
             else:
-                optimizer_cls = AdamW
+                optimizer_cls = torch.optim.AdamW
                 optimizer_kwargs = {
                     "betas": (self.args.adam_beta1, self.args.adam_beta2),
                     "eps": self.args.adam_epsilon,
@@ -41,7 +41,6 @@ from utils_qa import postprocess_qa_predictions_with_beam_search
 
 import transformers
 from transformers import (
-    AdamW,
     DataCollatorWithPadding,
     EvalPrediction,
     SchedulerType,
@@ -767,7 +766,7 @@ def main():
             "weight_decay": 0.0,
         },
     ]
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -33,7 +33,6 @@ from tqdm.auto import tqdm
 
 import transformers
 from transformers import (
-    AdamW,
     SchedulerType,
     Wav2Vec2Config,
     Wav2Vec2FeatureExtractor,
@@ -583,7 +582,7 @@ def main():
     )
 
     # Optimizer
-    optimizer = AdamW(
+    optimizer = torch.optim.AdamW(
         list(model.parameters()),
         lr=args.learning_rate,
         betas=[args.adam_beta1, args.adam_beta2],
@@ -4111,7 +4111,6 @@ else:
     )
     _import_structure["optimization"] = [
         "Adafactor",
-        "AdamW",
         "get_constant_schedule",
         "get_constant_schedule_with_warmup",
         "get_cosine_schedule_with_warmup",
@@ -8758,7 +8757,6 @@ if TYPE_CHECKING:
     # Optimization
     from .optimization import (
         Adafactor,
-        AdamW,
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
@@ -17,10 +17,9 @@
 import math
 import warnings
 from functools import partial
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
@@ -604,120 +603,6 @@ def get_scheduler(
     )
 
 
-class AdamW(Optimizer):
-    """
-    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
-    Regularization](https://arxiv.org/abs/1711.05101).
-
-    Parameters:
-        params (`Iterable[nn.parameter.Parameter]`):
-            Iterable of parameters to optimize or dictionaries defining parameter groups.
-        lr (`float`, *optional*, defaults to 0.001):
-            The learning rate to use.
-        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
-            Adam's betas parameters (b1, b2).
-        eps (`float`, *optional*, defaults to 1e-06):
-            Adam's epsilon for numerical stability.
-        weight_decay (`float`, *optional*, defaults to 0.0):
-            Decoupled weight decay to apply.
-        correct_bias (`bool`, *optional*, defaults to `True`):
-            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
-        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
-            A flag used to disable the deprecation warning (set to `True` to disable the warning).
-    """
-
-    def __init__(
-        self,
-        params: Iterable[nn.parameter.Parameter],
-        lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
-        eps: float = 1e-6,
-        weight_decay: float = 0.0,
-        correct_bias: bool = True,
-        no_deprecation_warning: bool = False,
-    ):
-        if not no_deprecation_warning:
-            warnings.warn(
-                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
-                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
-                " warning",
-                FutureWarning,
-            )
-        require_version("torch>=1.5.0")  # add_ with alpha
-        if lr < 0.0:
-            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
-        if not 0.0 <= betas[0] < 1.0:
-            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
-        if not 0.0 <= betas[1] < 1.0:
-            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
-        if not 0.0 <= eps:
-            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
-        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
-        super().__init__(params, defaults)
-
-    @torch.no_grad()
-    def step(self, closure: Callable = None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            loss = closure()
-
-        for group in self.param_groups:
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
-
-                state = self.state[p]
-
-                # State initialization
-                if len(state) == 0:
-                    state["step"] = 0
-                    # Exponential moving average of gradient values
-                    state["exp_avg"] = torch.zeros_like(p)
-                    # Exponential moving average of squared gradient values
-                    state["exp_avg_sq"] = torch.zeros_like(p)
-
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                beta1, beta2 = group["betas"]
-
-                state["step"] += 1
-
-                # Decay the first and second moment running average coefficient
-                # In-place operations to update the averages at the same time
-                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
-                denom = exp_avg_sq.sqrt().add_(group["eps"])
-
-                step_size = group["lr"]
-                if group["correct_bias"]:  # No bias correction for Bert
-                    bias_correction1 = 1.0 - beta1 ** state["step"]
-                    bias_correction2 = 1.0 - beta2 ** state["step"]
-                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
-
-                p.addcdiv_(exp_avg, denom, value=-step_size)
-
-                # Just adding the square of the weights to the loss function is *not*
-                # the correct way of using L2 regularization/weight decay with Adam,
-                # since that will interact with the m and v parameters in strange ways.
-                #
-                # Instead we want to decay the weights in a manner that doesn't interact
-                # with the m/v parameters. This is equivalent to adding the square
-                # of the weights to the loss with plain (non-momentum) SGD.
-                # Add weight decay at the end (fixed version)
-                if group["weight_decay"] > 0.0:
-                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
-
-        return loss
-
-
 class Adafactor(Optimizer):
     """
     AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
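The hunk above drops the deprecated in-house AdamW implementation. Its documented defaults (lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0) map directly onto the constructor of torch.optim.AdamW; note that the PyTorch optimizer always applies bias correction, so there is no direct counterpart to `correct_bias=False`. A minimal sketch of an equivalent setup, with an illustrative model and training step that are not part of this diff:

```python
import torch
from torch import nn

# Illustrative model; any iterable of parameters works the same way.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))

# Same hyperparameters as the removed transformers.AdamW defaults.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.0,
)

# One illustrative optimization step to show the optimizer in use.
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```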
@@ -1421,11 +1421,6 @@ class Trainer:
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
-        elif args.optim == OptimizerNames.ADAMW_HF:
-            from .optimization import AdamW
-
-            optimizer_cls = AdamW
-            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
            from torch.optim import AdamW
 
@@ -146,7 +146,6 @@ class OptimizerNames(ExplicitEnum):
     Stores the acceptable string identifiers for optimizers.
     """
 
-    ADAMW_HF = "adamw_hf"
     ADAMW_TORCH = "adamw_torch"
     ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
@@ -628,7 +627,7 @@ class TrainingArguments:
 
             The options should be separated by whitespaces.
         optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
-            The optimizer to use, such as "adamw_hf", "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
+            The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
             "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
             for a full list of optimizers.
         optim_args (`str`, *optional*):
@@ -2986,7 +2985,7 @@ class TrainingArguments:
 
         Args:
             name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
-                The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
+                The optimizer to use: `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
                 `"adamw_anyprecision"` or `"adafactor"`.
             learning_rate (`float`, *optional*, defaults to 5e-5):
                 The initial learning rate.
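The docstring edits above track the removal of the `adamw_hf` option; selecting the PyTorch AdamW through `TrainingArguments` is unchanged. A minimal sketch, with an illustrative output directory and hyperparameter values:

```python
from transformers import TrainingArguments

# "adamw_torch" is the documented default, so setting it explicitly is optional.
args = TrainingArguments(
    output_dir="out",      # illustrative output directory
    optim="adamw_torch",   # resolved to torch.optim.AdamW by the Trainer
    learning_rate=5e-5,
    weight_decay=0.01,
)
```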
@@ -10856,13 +10856,6 @@ class Adafactor(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
-class AdamW(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
 def get_constant_schedule(*args, **kwargs):
     requires_backends(get_constant_schedule, ["torch"])
 
@@ -535,7 +535,6 @@ from accelerate import Accelerator
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
-    AdamW,
     AutoConfig,
     {{cookiecutter.model_class}},
     AutoTokenizer,
@@ -863,7 +862,7 @@ def main():
             "weight_decay": 0.0,
         },
     ]
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 
     # Prepare everything with our `accelerator`.
     model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
@@ -28,7 +28,6 @@ if is_torch_available():
 
     from transformers import (
         Adafactor,
-        AdamW,
         get_constant_schedule,
         get_constant_schedule_with_warmup,
         get_cosine_schedule_with_warmup,
@@ -76,7 +75,7 @@ class OptimizationTest(unittest.TestCase):
         target = torch.tensor([0.4, 0.2, -0.5])
         criterion = nn.MSELoss()
         # No warmup, constant schedule, no gradient clipping
-        optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
+        optimizer = torch.optim.AdamW(params=[w], lr=2e-1, weight_decay=0.0)
         for _ in range(100):
             loss = criterion(w, target)
             loss.backward()
@@ -114,7 +113,7 @@ class OptimizationTest(unittest.TestCase):
 @require_torch
 class ScheduleInitTest(unittest.TestCase):
     m = nn.Linear(50, 50) if is_torch_available() else None
-    optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
+    optimizer = torch.optim.AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
     num_steps = 10
 
     def assertListAlmostEqual(self, list1, list2, tol, msg=None):
@@ -5375,16 +5375,6 @@ if is_torch_available():
     }
 
     optim_test_params = [
-        (
-            OptimizerNames.ADAMW_HF,
-            transformers.optimization.AdamW,
-            default_adam_kwargs,
-        ),
-        (
-            OptimizerNames.ADAMW_HF.value,
-            transformers.optimization.AdamW,
-            default_adam_kwargs,
-        ),
         (
             OptimizerNames.ADAMW_TORCH,
             torch.optim.AdamW,