Just import torch AdamW instead (#36177)

* Just import torch AdamW instead

* Update docs too

* Make AdamW undocumented

* make fixup

* Add a basic wrapper class

* Add it back to the docs

* Just remove AdamW entirely

* Remove some AdamW references

* Drop AdamW from the public init

* make fix-copies

* Cleanup some references

* make fixup

* Delete lots of transformers.AdamW references

* Remove extra references to adamw_hf
Matt 2025-03-19 18:29:40 +00:00 committed by GitHub
parent 51bd0ceb9e
commit 9be4728af8
18 changed files with 18 additions and 174 deletions
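
For downstream code that still imports the removed class, the migration is a one-line change. A minimal sketch (the tiny `nn.Linear` model and the `lr`/`eps` values below are illustrative, not taken from this diff):

```python
import torch
from torch import nn

model = nn.Linear(8, 2)  # stand-in for any PyTorch/transformers model

# Before (removed by this PR):
#   from transformers import AdamW
#   optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)

# After: use the PyTorch implementation directly.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
```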

View File

@@ -22,9 +22,6 @@ The `.optimization` module provides:
 - several schedules in the form of schedule objects that inherit from `_LRSchedule`:
 - a gradient accumulation class to accumulate the gradients of multiple batches
-## AdamW (PyTorch)
-[[autodoc]] AdamW
 ## AdaFactor (PyTorch)

View File

@@ -22,10 +22,6 @@ rendered properly in your Markdown viewer.
 - `_LRSchedule` から継承するスケジュール オブジェクトの形式のいくつかのスケジュール:
 - 複数のバッチの勾配を累積するための勾配累積クラス
-## AdamW (PyTorch)
-[[autodoc]] AdamW
 ## AdaFactor (PyTorch)
 [[autodoc]] Adafactor

View File

@@ -22,10 +22,6 @@ rendered properly in your Markdown viewer.
 - 继承自 `_LRSchedule` 多个调度器:
 - 一个梯度累积类,用于累积多个批次的梯度
-## AdamW (PyTorch)
-[[autodoc]] AdamW
 ## AdaFactor (PyTorch)
 [[autodoc]] Adafactor

View File

@@ -8,7 +8,6 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info
 from transformers import (
-AdamW,
 AutoConfig,
 AutoModel,
 AutoModelForPreTraining,
@@ -20,6 +19,7 @@ from transformers import (
 AutoTokenizer,
 PretrainedConfig,
 PreTrainedTokenizer,
+is_torch_available,
 )
 from transformers.optimization import (
 Adafactor,
@@ -31,6 +31,10 @@ from transformers.optimization import (
 from transformers.utils.versions import require_version
+if is_torch_available():
+import torch
 logger = logging.getLogger(__name__)
 require_version("pytorch_lightning>=1.0.4")
@@ -146,7 +150,7 @@ class BaseTransformer(pl.LightningModule):
 )
 else:
-optimizer = AdamW(
+optimizer = torch.optim.AdamW(
 optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
 )
 self.opt = optimizer

View File

@@ -32,7 +32,6 @@ import transformers
 from transformers import (
 MODEL_FOR_QUESTION_ANSWERING_MAPPING,
 WEIGHTS_NAME,
-AdamW,
 AutoConfig,
 AutoModelForQuestionAnswering,
 AutoTokenizer,
@@ -96,7 +95,7 @@ def train(args, train_dataset, model, tokenizer):
 },
 {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
 ]
-optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
 scheduler = get_linear_schedule_with_warmup(
 optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
 )

View File

@@ -43,7 +43,6 @@ from tqdm import tqdm, trange
 from transformers import (
 CONFIG_NAME,
 WEIGHTS_NAME,
-AdamW,
 OpenAIGPTDoubleHeadsModel,
 OpenAIGPTTokenizer,
 get_linear_schedule_with_warmup,
@@ -236,7 +235,7 @@ def main():
 },
 {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
 ]
-optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
 scheduler = get_linear_schedule_with_warmup(
 optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
 )

View File

@@ -34,7 +34,6 @@ from tqdm import tqdm, trange
 import transformers
 from transformers import (
 WEIGHTS_NAME,
-AdamW,
 AutoConfig,
 AutoModelForMultipleChoice,
 AutoTokenizer,
@@ -298,7 +297,7 @@ def train(args, train_dataset, model, tokenizer):
 },
 {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
 ]
-optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
 scheduler = get_linear_schedule_with_warmup(
 optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
 )

View File

@@ -22,7 +22,6 @@ from transformers import PreTrainedModel, Trainer, logging
 from transformers.models.fsmt.configuration_fsmt import FSMTConfig
 from transformers.optimization import (
 Adafactor,
-AdamW,
 get_constant_schedule,
 get_constant_schedule_with_warmup,
 get_cosine_schedule_with_warmup,
@@ -102,12 +101,11 @@ class Seq2SeqTrainer(Trainer):
 "weight_decay": 0.0,
 },
 ]
-optimizer_cls = Adafactor if self.args.adafactor else AdamW
 if self.args.adafactor:
 optimizer_cls = Adafactor
 optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
 else:
-optimizer_cls = AdamW
+optimizer_cls = torch.optim.AdamW
 optimizer_kwargs = {
 "betas": (self.args.adam_beta1, self.args.adam_beta2),
 "eps": self.args.adam_epsilon,

View File

@@ -41,7 +41,6 @@ from utils_qa import postprocess_qa_predictions_with_beam_search
 import transformers
 from transformers import (
-AdamW,
 DataCollatorWithPadding,
 EvalPrediction,
 SchedulerType,
@@ -767,7 +766,7 @@ def main():
 "weight_decay": 0.0,
 },
 ]
-optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 # Scheduler and math around the number of training steps.
 overrode_max_train_steps = False

View File

@@ -33,7 +33,6 @@ from tqdm.auto import tqdm
 import transformers
 from transformers import (
-AdamW,
 SchedulerType,
 Wav2Vec2Config,
 Wav2Vec2FeatureExtractor,
@@ -583,7 +582,7 @@ def main():
 )
 # Optimizer
-optimizer = AdamW(
+optimizer = torch.optim.AdamW(
 list(model.parameters()),
 lr=args.learning_rate,
 betas=[args.adam_beta1, args.adam_beta2],

View File

@@ -4111,7 +4111,6 @@ else:
 )
 _import_structure["optimization"] = [
 "Adafactor",
-"AdamW",
 "get_constant_schedule",
 "get_constant_schedule_with_warmup",
 "get_cosine_schedule_with_warmup",
@@ -8758,7 +8757,6 @@ if TYPE_CHECKING:
 # Optimization
 from .optimization import (
 Adafactor,
-AdamW,
 get_constant_schedule,
 get_constant_schedule_with_warmup,
 get_cosine_schedule_with_warmup,

View File

@@ -17,10 +17,9 @@
 import math
 import warnings
 from functools import partial
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Optional, Union
 import torch
-from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
@@ -604,120 +603,6 @@ def get_scheduler(
 )
-class AdamW(Optimizer):
-"""
-Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
-Regularization](https://arxiv.org/abs/1711.05101).
-Parameters:
-params (`Iterable[nn.parameter.Parameter]`):
-Iterable of parameters to optimize or dictionaries defining parameter groups.
-lr (`float`, *optional*, defaults to 0.001):
-The learning rate to use.
-betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
-Adam's betas parameters (b1, b2).
-eps (`float`, *optional*, defaults to 1e-06):
-Adam's epsilon for numerical stability.
-weight_decay (`float`, *optional*, defaults to 0.0):
-Decoupled weight decay to apply.
-correct_bias (`bool`, *optional*, defaults to `True`):
-Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
-no_deprecation_warning (`bool`, *optional*, defaults to `False`):
-A flag used to disable the deprecation warning (set to `True` to disable the warning).
-"""
-def __init__(
-self,
-params: Iterable[nn.parameter.Parameter],
-lr: float = 1e-3,
-betas: Tuple[float, float] = (0.9, 0.999),
-eps: float = 1e-6,
-weight_decay: float = 0.0,
-correct_bias: bool = True,
-no_deprecation_warning: bool = False,
-):
-if not no_deprecation_warning:
-warnings.warn(
-"This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
-" implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
-" warning",
-FutureWarning,
-)
-require_version("torch>=1.5.0") # add_ with alpha
-if lr < 0.0:
-raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
-if not 0.0 <= betas[0] < 1.0:
-raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
-if not 0.0 <= betas[1] < 1.0:
-raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
-if not 0.0 <= eps:
-raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
-defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
-super().__init__(params, defaults)
-@torch.no_grad()
-def step(self, closure: Callable = None):
-"""
-Performs a single optimization step.
-Arguments:
-closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-"""
-loss = None
-if closure is not None:
-loss = closure()
-for group in self.param_groups:
-for p in group["params"]:
-if p.grad is None:
-continue
-grad = p.grad
-if grad.is_sparse:
-raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
-state = self.state[p]
-# State initialization
-if len(state) == 0:
-state["step"] = 0
-# Exponential moving average of gradient values
-state["exp_avg"] = torch.zeros_like(p)
-# Exponential moving average of squared gradient values
-state["exp_avg_sq"] = torch.zeros_like(p)
-exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-beta1, beta2 = group["betas"]
-state["step"] += 1
-# Decay the first and second moment running average coefficient
-# In-place operations to update the averages at the same time
-exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
-exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
-denom = exp_avg_sq.sqrt().add_(group["eps"])
-step_size = group["lr"]
-if group["correct_bias"]: # No bias correction for Bert
-bias_correction1 = 1.0 - beta1 ** state["step"]
-bias_correction2 = 1.0 - beta2 ** state["step"]
-step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
-p.addcdiv_(exp_avg, denom, value=-step_size)
-# Just adding the square of the weights to the loss function is *not*
-# the correct way of using L2 regularization/weight decay with Adam,
-# since that will interact with the m and v parameters in strange ways.
-#
-# Instead we want to decay the weights in a manner that doesn't interact
-# with the m/v parameters. This is equivalent to adding the square
-# of the weights to the loss with plain (non-momentum) SGD.
-# Add weight decay at the end (fixed version)
-if group["weight_decay"] > 0.0:
-p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
-return loss
 class Adafactor(Optimizer):
 """
 AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
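
Note on drop-in equivalence (a sketch, not part of this diff): `torch.optim.AdamW` implements the same decoupled weight decay update as the deleted class, but its defaults differ (`eps=1e-8` and `weight_decay=1e-2` instead of `eps=1e-6` and `weight_decay=0.0`), and it always applies bias correction, so there is no `correct_bias` flag. Passing the old values explicitly keeps the behaviour numerically close:

```python
import torch
from torch import nn

params = nn.Linear(8, 2).parameters()  # illustrative parameters

# Match the defaults of the removed transformers.AdamW (correct_bias=True case).
optimizer = torch.optim.AdamW(
    params,
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.0,
)
```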

View File

@ -1421,11 +1421,6 @@ class Trainer:
 if args.optim == OptimizerNames.ADAFACTOR:
 optimizer_cls = Adafactor
 optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
-elif args.optim == OptimizerNames.ADAMW_HF:
-from .optimization import AdamW
-optimizer_cls = AdamW
-optimizer_kwargs.update(adam_kwargs)
 elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
 from torch.optim import AdamW

View File

@@ -146,7 +146,6 @@ class OptimizerNames(ExplicitEnum):
 Stores the acceptable string identifiers for optimizers.
 """
-ADAMW_HF = "adamw_hf"
 ADAMW_TORCH = "adamw_torch"
 ADAMW_TORCH_FUSED = "adamw_torch_fused"
 ADAMW_TORCH_XLA = "adamw_torch_xla"
@@ -628,7 +627,7 @@ class TrainingArguments:
 The options should be separated by whitespaces.
 optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
-The optimizer to use, such as "adamw_hf", "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
+The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
 "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
 for a full list of optimizers.
 optim_args (`str`, *optional*):
@ -2986,7 +2985,7 @@ class TrainingArguments:
 Args:
 name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
-The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
+The optimizer to use: `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
 `"adamw_anyprecision"` or `"adafactor"`.
 learning_rate (`float`, *optional*, defaults to 5e-5):
 The initial learning rate.
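
With `"adamw_hf"` gone from `OptimizerNames`, `Trainer` configurations should rely on the torch-backed optimizers, which were already the default. A usage sketch (the `output_dir` and `learning_rate` values are placeholders):

```python
from transformers import TrainingArguments

# "adamw_hf" is no longer accepted here; "adamw_torch" (the default) replaces it.
args = TrainingArguments(
    output_dir="output",        # placeholder path
    optim="adamw_torch",        # or e.g. "adamw_torch_fused" on recent PyTorch builds
    learning_rate=5e-5,
)
```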

View File

@ -10856,13 +10856,6 @@ class Adafactor(metaclass=DummyObject):
 requires_backends(self, ["torch"])
-class AdamW(metaclass=DummyObject):
-_backends = ["torch"]
-def __init__(self, *args, **kwargs):
-requires_backends(self, ["torch"])
 def get_constant_schedule(*args, **kwargs):
 requires_backends(get_constant_schedule, ["torch"])

View File

@@ -535,7 +535,6 @@ from accelerate import Accelerator
 from transformers import (
 CONFIG_MAPPING,
 MODEL_MAPPING,
-AdamW,
 AutoConfig,
 {{cookiecutter.model_class}},
 AutoTokenizer,
@@ -863,7 +862,7 @@ def main():
 "weight_decay": 0.0,
 },
 ]
-optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 # Prepare everything with our `accelerator`.
 model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(

View File

@@ -28,7 +28,6 @@ if is_torch_available():
 from transformers import (
 Adafactor,
-AdamW,
 get_constant_schedule,
 get_constant_schedule_with_warmup,
 get_cosine_schedule_with_warmup,
@@ -76,7 +75,7 @@ class OptimizationTest(unittest.TestCase):
 target = torch.tensor([0.4, 0.2, -0.5])
 criterion = nn.MSELoss()
 # No warmup, constant schedule, no gradient clipping
-optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
+optimizer = torch.optim.AdamW(params=[w], lr=2e-1, weight_decay=0.0)
 for _ in range(100):
 loss = criterion(w, target)
 loss.backward()
@@ -114,7 +113,7 @@ class OptimizationTest(unittest.TestCase):
 @require_torch
 class ScheduleInitTest(unittest.TestCase):
 m = nn.Linear(50, 50) if is_torch_available() else None
-optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
+optimizer = torch.optim.AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
 num_steps = 10
 def assertListAlmostEqual(self, list1, list2, tol, msg=None):

View File

@@ -5375,16 +5375,6 @@ if is_torch_available():
 }
 optim_test_params = [
-(
-OptimizerNames.ADAMW_HF,
-transformers.optimization.AdamW,
-default_adam_kwargs,
-),
-(
-OptimizerNames.ADAMW_HF.value,
-transformers.optimization.AdamW,
-default_adam_kwargs,
-),
 (
 OptimizerNames.ADAMW_TORCH,
 torch.optim.AdamW,