Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Deprecates AdamW and adds --optim (#14744)

* Add AdamW deprecation warning
* Add --optim to Trainer
* Update src/transformers/optimization.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/optimization.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/optimization.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/optimization.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/training_args.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/training_args.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/training_args.py
* fix style
* fix
* Regroup adamws together (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Change --adafactor to --optim adafactor
* Use Enum for optimizer values
* fixup! Change --adafactor to --optim adafactor
* fixup! Change --adafactor to --optim adafactor
* fixup! Change --adafactor to --optim adafactor
* fixup! Use Enum for optimizer values
* Improved documentation for --adafactor (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Add mention of no_deprecation_warning (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Rename OptimizerOptions to OptimizerNames
* Use choices for --optim
* Move optimizer selection code to a function and add a unit test
* Change optimizer names
* Rename method (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Rename method (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Remove TODO comment (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Rename variable (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Rename variable (Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>)
* Rename function
* Rename variable
* Parameterize the tests for supported optimizers
* Refactor
* Attempt to make tests pass on CircleCI
* Add a test with apex
* rework to add apex to parameterized; add actual train test
* fix import when torch is not available
* fix optim_test_params when torch is not available
* fix optim_test_params when torch is not available
* re-org
* small re-org
* fix test_fused_adam_no_apex
* Update src/transformers/training_args.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/training_args.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/training_args.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Remove .value from OptimizerNames
* Rename optimizer strings s|--adam_|--adamw_|
* Also rename Enum options
* small fix
* Fix instantiation of OptimizerNames. Remove redundant test
* Use ExplicitEnum instead of Enum
* Add unit test with string optimizer
* Change optimizer default to string value

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
commit 7b83feb50a (parent 762416ffa8)
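For context, a minimal sketch (not part of this commit) of how the new optimizer selection is meant to be used from user code; the output directory and hyperparameters are placeholder assumptions, and the accepted names come from the OptimizerNames enum added in training_args.py:

# Sketch only: pick the optimizer by name via the new `optim` argument.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",        # placeholder
    optim="adamw_torch",     # one of: adamw_hf, adamw_torch, adamw_apex_fused, adafactor
    learning_rate=5e-5,      # placeholder
)
print(args.optim)  # coerced to OptimizerNames.ADAMW_TORCH in __post_init__

On the command line the equivalent is passing `--optim adamw_torch`; choosing Adafactor through the old `--adafactor` flag still works but is deprecated in favor of `--optim adafactor`.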
src/transformers/optimization.py
@@ -15,6 +15,7 @@
 """PyTorch optimization for BERT model."""
 
 import math
+import warnings
 from typing import Callable, Iterable, Optional, Tuple, Union
 
 import torch
@@ -287,6 +288,8 @@ class AdamW(Optimizer):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
+        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
+            A flag used to disable the deprecation warning (set to `True` to disable the warning).
     """
 
     def __init__(
@@ -297,7 +300,14 @@ class AdamW(Optimizer):
         eps: float = 1e-6,
         weight_decay: float = 0.0,
         correct_bias: bool = True,
+        no_deprecation_warning: bool = False,
     ):
+        if not no_deprecation_warning:
+            warnings.warn(
+                "This implementation of AdamW is deprecated and will be removed in a future version. Use the"
+                " PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning",
+                FutureWarning,
+            )
         require_version("torch>=1.5.0")  # add_ with alpha
         if lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
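The hunk above only adds the warning itself; as a hedged illustration, this is roughly what the deprecation asks downstream code to do (the tiny model below is a placeholder):

# Sketch only: replacing the deprecated transformers AdamW with the PyTorch one.
import torch
from torch import nn

model = nn.Linear(4, 2)  # placeholder model

# Old path, now emits a FutureWarning unless no_deprecation_warning=True is passed:
# from transformers.optimization import AdamW
# optimizer = AdamW(model.parameters(), lr=5e-5, no_deprecation_warning=True)

# Replacement suggested by the warning text:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-6)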
src/transformers/trainer.py
@@ -77,7 +77,7 @@ from .file_utils import (
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
-from .optimization import Adafactor, AdamW, get_scheduler
+from .optimization import Adafactor, get_scheduler
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -128,7 +128,7 @@ from .trainer_utils import (
     set_seed,
     speed_metrics,
 )
-from .training_args import ParallelMode, TrainingArguments
+from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import logging
 
 
@@ -819,17 +819,9 @@ class Trainer:
                     "weight_decay": 0.0,
                 },
             ]
-            optimizer_cls = Adafactor if self.args.adafactor else AdamW
-            if self.args.adafactor:
-                optimizer_cls = Adafactor
-                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
-            else:
-                optimizer_cls = AdamW
-                optimizer_kwargs = {
-                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
-                    "eps": self.args.adam_epsilon,
-                }
-            optimizer_kwargs["lr"] = self.args.learning_rate
+
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
             if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                 self.optimizer = OSS(
                     params=optimizer_grouped_parameters,
@@ -844,6 +836,46 @@ class Trainer:
 
         return self.optimizer
 
+    @staticmethod
+    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
+        """
+        Returns the optimizer class and optimizer parameters based on the training arguments.
+
+        Args:
+            args (`transformers.training_args.TrainingArguments`):
+                The training arguments for the training session.
+
+        """
+        optimizer_kwargs = {"lr": args.learning_rate}
+        adam_kwargs = {
+            "betas": (args.adam_beta1, args.adam_beta2),
+            "eps": args.adam_epsilon,
+        }
+        if args.optim == OptimizerNames.ADAFACTOR:
+            optimizer_cls = Adafactor
+            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim == OptimizerNames.ADAMW_HF:
+            from .optimization import AdamW
+
+            optimizer_cls = AdamW
+            optimizer_kwargs.update(adam_kwargs)
+        elif args.optim == OptimizerNames.ADAMW_TORCH:
+            from torch.optim import AdamW
+
+            optimizer_cls = AdamW
+            optimizer_kwargs.update(adam_kwargs)
+        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
+            try:
+                from apex.optimizers import FusedAdam
+
+                optimizer_cls = FusedAdam
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        else:
+            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+        return optimizer_cls, optimizer_kwargs
+
     def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
         """
         Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
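To see the new selection path in isolation, a minimal sketch that mirrors the unit tests added at the bottom of this commit; `output_dir="None"` is just the placeholder string the tests use:

# Sketch only: resolve the optimizer class and kwargs without building a Trainer.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(optim="adamw_torch", output_dir="None")
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
# optimizer_cls -> torch.optim.AdamW
# optimizer_kwargs -> lr, betas and eps taken from learning_rate, adam_beta1/adam_beta2 and adam_epsilon
print(optimizer_cls, optimizer_kwargs)

This is the same static helper that `create_optimizer` now delegates to, so the name-to-class mapping can be unit-tested without running a full training loop.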
src/transformers/training_args.py
@@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional
 
 from .debug_utils import DebugOption
 from .file_utils import (
+    ExplicitEnum,
     cached_property,
     get_full_repo_name,
     is_sagemaker_dp_enabled,
@@ -69,6 +70,17 @@ def default_logdir() -> str:
     return os.path.join("runs", current_time + "_" + socket.gethostname())
 
 
+class OptimizerNames(ExplicitEnum):
+    """
+    Stores the acceptable string identifiers for optimizers.
+    """
+
+    ADAMW_HF = "adamw_hf"
+    ADAMW_TORCH = "adamw_torch"
+    ADAMW_APEX_FUSED = "adamw_apex_fused"
+    ADAFACTOR = "adafactor"
+
+
 @dataclass
 class TrainingArguments:
     """
@@ -327,8 +339,10 @@ class TrainingArguments:
            - `"tpu_metrics_debug"`: print debug metrics on TPU
 
            The options should be separated by whitespaces.
+        optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
+            The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
        adafactor (`bool`, *optional*, defaults to `False`):
-            Whether or not to use the [`Adafactor`] optimizer instead of [`AdamW`].
+            This argument is deprecated. Use `--optim adafactor` instead.
        group_by_length (`bool`, *optional*, defaults to `False`):
            Whether or not to group together samples of roughly the same length in the training dataset (to minimize
            padding applied and be more efficient). Only useful if applying dynamic padding.
@@ -641,6 +655,10 @@ class TrainingArguments:
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
+    optim: OptimizerNames = field(
+        default="adamw_hf",
+        metadata={"help": "The optimizer to use."},
+    )
     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
     group_by_length: bool = field(
         default=False,
@@ -809,6 +827,15 @@ class TrainingArguments:
                 )
             if not (self.sharded_ddp == "" or not self.sharded_ddp):
                 raise ValueError("sharded_ddp is not supported with bf16")
+
+        self.optim = OptimizerNames(self.optim)
+        if self.adafactor:
+            warnings.warn(
+                "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim adafactor` instead",
+                FutureWarning,
+            )
+            self.optim = OptimizerNames.ADAFACTOR
+
         if (
             is_torch_available()
             and self.device.type != "cuda"
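A short sketch of how the new enum and the `--adafactor` back-compatibility path behave after `__post_init__`, per the hunks above; the output directory is a placeholder:

# Sketch only: string values are coerced to OptimizerNames, and the deprecated
# adafactor flag is remapped onto the new optim field with a FutureWarning.
import warnings

from transformers import TrainingArguments
from transformers.training_args import OptimizerNames

assert OptimizerNames("adamw_hf") is OptimizerNames.ADAMW_HF  # ExplicitEnum lookup by value

with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    args = TrainingArguments(output_dir="out", adafactor=True)  # deprecated flag
assert args.optim == OptimizerNames.ADAFACTOR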
tests/test_trainer.py
@@ -23,10 +23,12 @@ import subprocess
 import tempfile
 import unittest
 from pathlib import Path
+from unittest.mock import Mock, patch
 
 import numpy as np
 
 from huggingface_hub import Repository, delete_repo, login
+from parameterized import parameterized
 from requests.exceptions import HTTPError
 from transformers import (
     AutoTokenizer,
@@ -36,7 +38,7 @@ from transformers import (
     is_torch_available,
     logging,
 )
-from transformers.file_utils import WEIGHTS_NAME
+from transformers.file_utils import WEIGHTS_NAME, is_apex_available
 from transformers.testing_utils import (
     ENDPOINT_STAGING,
     PASS,
@@ -61,6 +63,7 @@ from transformers.testing_utils import (
     slow,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.training_args import OptimizerNames
 from transformers.utils.hp_naming import TrialShortNamer
 
 
@@ -69,6 +72,7 @@ if is_torch_available():
     from torch import nn
     from torch.utils.data import IterableDataset
 
+    import transformers.optimization
     from transformers import (
         AutoModelForSequenceClassification,
         EarlyStoppingCallback,
@@ -1711,3 +1715,98 @@ class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
         trainer.hyperparameter_search(
             direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4
         )
+
+
+optim_test_params = []
+if is_torch_available():
+    default_adam_kwargs = {
+        "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
+        "eps": TrainingArguments.adam_epsilon,
+        "lr": TrainingArguments.learning_rate,
+    }
+
+    optim_test_params = [
+        (
+            OptimizerNames.ADAMW_HF,
+            transformers.optimization.AdamW,
+            default_adam_kwargs,
+        ),
+        (
+            OptimizerNames.ADAMW_HF.value,
+            transformers.optimization.AdamW,
+            default_adam_kwargs,
+        ),
+        (
+            OptimizerNames.ADAMW_TORCH,
+            torch.optim.AdamW,
+            default_adam_kwargs,
+        ),
+        (
+            OptimizerNames.ADAFACTOR,
+            transformers.optimization.Adafactor,
+            {
+                "scale_parameter": False,
+                "relative_step": False,
+                "lr": TrainingArguments.learning_rate,
+            },
+        ),
+    ]
+    if is_apex_available():
+        import apex
+
+        optim_test_params.append(
+            (
+                OptimizerNames.ADAMW_APEX_FUSED,
+                apex.optimizers.FusedAdam,
+                default_adam_kwargs,
+            )
+        )
+
+
+@require_torch
+class TrainerOptimizerChoiceTest(unittest.TestCase):
+    def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls):
+        args = TrainingArguments(optim=optim, output_dir="None")
+        actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
+        self.assertEqual(expected_cls, actual_cls)
+        self.assertIsNotNone(optim_kwargs)
+
+        for p, v in mandatory_kwargs.items():
+            self.assertTrue(p in optim_kwargs)
+            actual_v = optim_kwargs[p]
+            self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")
+
+    @parameterized.expand(optim_test_params, skip_on_empty=True)
+    def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):
+        # exercises all the valid --optim options
+        self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls)
+
+        trainer = get_regression_trainer(optim=name)
+        trainer.train()
+
+    def test_fused_adam(self):
+        # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
+        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
+        # class, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
+        # the test to run without requiring an apex installation.
+        mock = Mock()
+        modules = {
+            "apex": mock,
+            "apex.optimizers": mock.optimizers,
+            "apex.optimizers.FusedAdam": mock.optimizers.FusedAdam,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                OptimizerNames.ADAMW_APEX_FUSED,
+                default_adam_kwargs,
+                mock.optimizers.FusedAdam,
+            )
+
+    def test_fused_adam_no_apex(self):
+        args = TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None")
+
+        # Pretend that apex does not exist, even if installed. By setting apex to None, importing
+        # apex will fail even if apex is installed.
+        with patch.dict("sys.modules", {"apex.optimizers": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
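Finally, a hedged sketch of running only the new optimizer tests with the standard unittest loader; the module name `test_trainer` is an assumption about where this test file lives:

# Sketch only: load and run TrainerOptimizerChoiceTest (module name assumed).
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("test_trainer.TrainerOptimizerChoiceTest")
unittest.TextTestRunner(verbosity=2).run(suite)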