mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Deprecates AdamW and adds --optim
(#14744)
* Add AdamW deprecation warning * Add --optim to Trainer * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py * fix style * fix * Regroup adamws together Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Change --adafactor to --optim adafactor * Use Enum for optimizer values * fixup! Change --adafactor to --optim adafactor * fixup! Change --adafactor to --optim adafactor * fixup! Change --adafactor to --optim adafactor * fixup! Use Enum for optimizer values * Improved documentation for --adafactor Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Add mention of no_deprecation_warning Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename OptimizerOptions to OptimizerNames * Use choices for --optim * Move optimizer selection code to a function and add a unit test * Change optimizer names * Rename method Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename method Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Remove TODO comment Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename variable Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename variable Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename function * Rename variable * Parameterize the tests for supported optimizers * Refactor * Attempt to make tests pass on CircleCI * Add a test with apex * rework to add apex to parameterized; add actual train test * fix import when torch is not available * fix optim_test_params when torch is not available * fix optim_test_params when torch is not available * re-org * small re-org * fix test_fused_adam_no_apex * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove .value from OptimizerNames * Rename optimizer strings s|--adam_|--adamw_| * Also rename Enum options * small fix * Fix instantiation of OptimizerNames. Remove redundant test * Use ExplicitEnum instead of Enum * Add unit test with string optimizer * Change optimizer default to string value Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
parent
762416ffa8
commit
7b83feb50a
@ -15,6 +15,7 @@
|
||||
"""PyTorch optimization for BERT model."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Callable, Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -287,6 +288,8 @@ class AdamW(Optimizer):
|
||||
Decoupled weight decay to apply.
|
||||
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
|
||||
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
|
||||
A flag used to disable the deprecation warning (set to `True` to disable the warning).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -297,7 +300,14 @@ class AdamW(Optimizer):
|
||||
eps: float = 1e-6,
|
||||
weight_decay: float = 0.0,
|
||||
correct_bias: bool = True,
|
||||
no_deprecation_warning: bool = False,
|
||||
):
|
||||
if not no_deprecation_warning:
|
||||
warnings.warn(
|
||||
"This implementation of AdamW is deprecated and will be removed in a future version. Use the"
|
||||
"PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning",
|
||||
FutureWarning,
|
||||
)
|
||||
require_version("torch>=1.5.0") # add_ with alpha
|
||||
if lr < 0.0:
|
||||
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
|
||||
|
@ -77,7 +77,7 @@ from .file_utils import (
|
||||
from .modelcard import TrainingSummary
|
||||
from .modeling_utils import PreTrainedModel, unwrap_model
|
||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
||||
from .optimization import Adafactor, AdamW, get_scheduler
|
||||
from .optimization import Adafactor, get_scheduler
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from .trainer_callback import (
|
||||
CallbackHandler,
|
||||
@ -128,7 +128,7 @@ from .trainer_utils import (
|
||||
set_seed,
|
||||
speed_metrics,
|
||||
)
|
||||
from .training_args import ParallelMode, TrainingArguments
|
||||
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@ -819,17 +819,9 @@ class Trainer:
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer_cls = Adafactor if self.args.adafactor else AdamW
|
||||
if self.args.adafactor:
|
||||
optimizer_cls = Adafactor
|
||||
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
|
||||
else:
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs = {
|
||||
"betas": (self.args.adam_beta1, self.args.adam_beta2),
|
||||
"eps": self.args.adam_epsilon,
|
||||
}
|
||||
optimizer_kwargs["lr"] = self.args.learning_rate
|
||||
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
||||
|
||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||
self.optimizer = OSS(
|
||||
params=optimizer_grouped_parameters,
|
||||
@ -844,6 +836,46 @@ class Trainer:
|
||||
|
||||
return self.optimizer
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Returns the optimizer class and optimizer parameters based on the training arguments.
|
||||
|
||||
Args:
|
||||
args (`transformers.training_args.TrainingArguments`):
|
||||
The training arguments for the training session.
|
||||
|
||||
"""
|
||||
optimizer_kwargs = {"lr": args.learning_rate}
|
||||
adam_kwargs = {
|
||||
"betas": (args.adam_beta1, args.adam_beta2),
|
||||
"eps": args.adam_epsilon,
|
||||
}
|
||||
if args.optim == OptimizerNames.ADAFACTOR:
|
||||
optimizer_cls = Adafactor
|
||||
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
|
||||
elif args.optim == OptimizerNames.ADAMW_HF:
|
||||
from .optimization import AdamW
|
||||
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif args.optim == OptimizerNames.ADAMW_TORCH:
|
||||
from torch.optim import AdamW
|
||||
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
|
||||
try:
|
||||
from apex.optimizers import FusedAdam
|
||||
|
||||
optimizer_cls = FusedAdam
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
except ImportError:
|
||||
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
|
||||
else:
|
||||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
||||
return optimizer_cls, optimizer_kwargs
|
||||
|
||||
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
|
@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from .debug_utils import DebugOption
|
||||
from .file_utils import (
|
||||
ExplicitEnum,
|
||||
cached_property,
|
||||
get_full_repo_name,
|
||||
is_sagemaker_dp_enabled,
|
||||
@ -69,6 +70,17 @@ def default_logdir() -> str:
|
||||
return os.path.join("runs", current_time + "_" + socket.gethostname())
|
||||
|
||||
|
||||
class OptimizerNames(ExplicitEnum):
|
||||
"""
|
||||
Stores the acceptable string identifiers for optimizers.
|
||||
"""
|
||||
|
||||
ADAMW_HF = "adamw_hf"
|
||||
ADAMW_TORCH = "adamw_torch"
|
||||
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
||||
ADAFACTOR = "adafactor"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
"""
|
||||
@ -327,8 +339,10 @@ class TrainingArguments:
|
||||
- `"tpu_metrics_debug"`: print debug metrics on TPU
|
||||
|
||||
The options should be separated by whitespaces.
|
||||
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
|
||||
The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
|
||||
adafactor (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use the [`Adafactor`] optimizer instead of [`AdamW`].
|
||||
This argument is deprecated. Use `--optim adafactor` instead.
|
||||
group_by_length (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
|
||||
padding applied and be more efficient). Only useful if applying dynamic padding.
|
||||
@ -641,6 +655,10 @@ class TrainingArguments:
|
||||
label_smoothing_factor: float = field(
|
||||
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
||||
)
|
||||
optim: OptimizerNames = field(
|
||||
default="adamw_hf",
|
||||
metadata={"help": "The optimizer to use."},
|
||||
)
|
||||
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||
group_by_length: bool = field(
|
||||
default=False,
|
||||
@ -809,6 +827,15 @@ class TrainingArguments:
|
||||
)
|
||||
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
||||
raise ValueError("sharded_ddp is not supported with bf16")
|
||||
|
||||
self.optim = OptimizerNames(self.optim)
|
||||
if self.adafactor:
|
||||
warnings.warn(
|
||||
"`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim adafactor` instead",
|
||||
FutureWarning,
|
||||
)
|
||||
self.optim = OptimizerNames.ADAFACTOR
|
||||
|
||||
if (
|
||||
is_torch_available()
|
||||
and self.device.type != "cuda"
|
||||
|
@ -23,10 +23,12 @@ import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from parameterized import parameterized
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@ -36,7 +38,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
@ -61,6 +63,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
|
||||
|
||||
@ -69,6 +72,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
import transformers.optimization
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
EarlyStoppingCallback,
|
||||
@ -1711,3 +1715,98 @@ class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
|
||||
trainer.hyperparameter_search(
|
||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4
|
||||
)
|
||||
|
||||
|
||||
optim_test_params = []
|
||||
if is_torch_available():
|
||||
default_adam_kwargs = {
|
||||
"betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
|
||||
"eps": TrainingArguments.adam_epsilon,
|
||||
"lr": TrainingArguments.learning_rate,
|
||||
}
|
||||
|
||||
optim_test_params = [
|
||||
(
|
||||
OptimizerNames.ADAMW_HF,
|
||||
transformers.optimization.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAMW_HF.value,
|
||||
transformers.optimization.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAMW_TORCH,
|
||||
torch.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
),
|
||||
(
|
||||
OptimizerNames.ADAFACTOR,
|
||||
transformers.optimization.Adafactor,
|
||||
{
|
||||
"scale_parameter": False,
|
||||
"relative_step": False,
|
||||
"lr": TrainingArguments.learning_rate,
|
||||
},
|
||||
),
|
||||
]
|
||||
if is_apex_available():
|
||||
import apex
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
OptimizerNames.ADAMW_APEX_FUSED,
|
||||
apex.optimizers.FusedAdam,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls):
|
||||
args = TrainingArguments(optim=optim, output_dir="None")
|
||||
actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
self.assertEqual(expected_cls, actual_cls)
|
||||
self.assertIsNotNone(optim_kwargs)
|
||||
|
||||
for p, v in mandatory_kwargs.items():
|
||||
self.assertTrue(p in optim_kwargs)
|
||||
actual_v = optim_kwargs[p]
|
||||
self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")
|
||||
|
||||
@parameterized.expand(optim_test_params, skip_on_empty=True)
|
||||
def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):
|
||||
# exercises all the valid --optim options
|
||||
self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls)
|
||||
|
||||
trainer = get_regression_trainer(optim=name)
|
||||
trainer.train()
|
||||
|
||||
def test_fused_adam(self):
|
||||
# Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
|
||||
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
|
||||
# class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
|
||||
# the test to run without requiring an apex installation.
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"apex": mock,
|
||||
"apex.optimizers": mock.optimizers,
|
||||
"apex.optimizers.FusedAdam": mock.optimizers.FusedAdam,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
OptimizerNames.ADAMW_APEX_FUSED,
|
||||
default_adam_kwargs,
|
||||
mock.optimizers.FusedAdam,
|
||||
)
|
||||
|
||||
def test_fused_adam_no_apex(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None")
|
||||
|
||||
# Pretend that apex does not exist, even if installed. By setting apex to None, importing
|
||||
# apex will fail even if apex is installed.
|
||||
with patch.dict("sys.modules", {"apex.optimizers": None}):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
Loading…
Reference in New Issue
Block a user