Deprecates AdamW and adds --optim (#14744)

* Add AdamW deprecation warning

* Add --optim to Trainer

* Update src/transformers/optimization.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/optimization.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/optimization.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/optimization.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/training_args.py

* fix style

* fix

* Regroup adamws together

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Change --adafactor to --optim adafactor

* Use Enum for optimizer values

* fixup! Change --adafactor to --optim adafactor

* fixup! Change --adafactor to --optim adafactor

* fixup! Change --adafactor to --optim adafactor

* fixup! Use Enum for optimizer values

* Improved documentation for --adafactor

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Add mention of no_deprecation_warning

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename OptimizerOptions to OptimizerNames

* Use choices for --optim

* Move optimizer selection code to a function and add a unit test

* Change optimizer names

* Rename method

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename method

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Remove TODO comment

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename variable

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename variable

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename function

* Rename variable

* Parameterize the tests for supported optimizers

* Refactor

* Attempt to make tests pass on CircleCI

* Add a test with apex

* rework to add apex to parameterized; add actual train test

* fix import when torch is not available

* fix optim_test_params when torch is not available

* fix optim_test_params when torch is not available

* re-org

* small re-org

* fix test_fused_adam_no_apex

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove .value from OptimizerNames

* Rename optimizer strings s|--adam_|--adamw_|

* Also rename Enum options

* small fix

* Fix instantiation of OptimizerNames. Remove redundant test

* Use ExplicitEnum instead of Enum

* Add unit test with string optimizer

* Change optimizer default to string value

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
Author: Manuel R. Ciosici, 2022-01-13 11:14:51 -05:00 (committed by GitHub)
parent 762416ffa8
commit 7b83feb50a
4 changed files with 183 additions and 15 deletions
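Before the per-file diffs, a rough sketch of how the new flag is meant to be used. This is an illustrative example rather than part of the commit; the output directory name is a placeholder.

    from transformers import TrainingArguments

    # Pick the optimizer by name; "adamw_hf" stays the default, so existing
    # configurations keep their current behavior unless they opt in to another implementation.
    args = TrainingArguments(output_dir="test_output", optim="adamw_torch")

    # On the command line the equivalent is `--optim adamw_torch` (or adamw_hf,
    # adamw_apex_fused, adafactor). The old `--adafactor` flag still works but now
    # emits a FutureWarning and maps to `--optim adafactor`.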

File: src/transformers/optimization.py

@@ -15,6 +15,7 @@
 """PyTorch optimization for BERT model."""
 
 import math
+import warnings
 from typing import Callable, Iterable, Optional, Tuple, Union
 
 import torch
@@ -287,6 +288,8 @@ class AdamW(Optimizer):
             Decoupled weight decay to apply.
         correct_bias (`bool`, *optional*, defaults to `True`):
             Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
+        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
+            A flag used to disable the deprecation warning (set to `True` to disable the warning).
     """
 
     def __init__(
@@ -297,7 +300,14 @@ class AdamW(Optimizer):
         eps: float = 1e-6,
         weight_decay: float = 0.0,
         correct_bias: bool = True,
+        no_deprecation_warning: bool = False,
     ):
+        if not no_deprecation_warning:
+            warnings.warn(
+                "This implementation of AdamW is deprecated and will be removed in a future version. Use the"
+                " PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning",
+                FutureWarning,
+            )
         require_version("torch>=1.5.0")  # add_ with alpha
         if lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")

File: src/transformers/trainer.py

@@ -77,7 +77,7 @@ from .file_utils import (
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
-from .optimization import Adafactor, AdamW, get_scheduler
+from .optimization import Adafactor, get_scheduler
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -128,7 +128,7 @@ from .trainer_utils import (
     set_seed,
     speed_metrics,
 )
-from .training_args import ParallelMode, TrainingArguments
+from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import logging
@@ -819,17 +819,9 @@ class Trainer:
                     "weight_decay": 0.0,
                 },
             ]
-            optimizer_cls = Adafactor if self.args.adafactor else AdamW
-            if self.args.adafactor:
-                optimizer_cls = Adafactor
-                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
-            else:
-                optimizer_cls = AdamW
-                optimizer_kwargs = {
-                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
-                    "eps": self.args.adam_epsilon,
-                }
-            optimizer_kwargs["lr"] = self.args.learning_rate
+
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
             if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                 self.optimizer = OSS(
                     params=optimizer_grouped_parameters,
@@ -844,6 +836,46 @@
 
         return self.optimizer
 
+    @staticmethod
+    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
+        """
+        Returns the optimizer class and optimizer parameters based on the training arguments.
+
+        Args:
+            args (`transformers.training_args.TrainingArguments`):
+                The training arguments for the training session.
+        """
+        optimizer_kwargs = {"lr": args.learning_rate}
+        adam_kwargs = {
+            "betas": (args.adam_beta1, args.adam_beta2),
+            "eps": args.adam_epsilon,
+        }
+        if args.optim == OptimizerNames.ADAFACTOR:
+            optimizer_cls = Adafactor
+            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim == OptimizerNames.ADAMW_HF:
+            from .optimization import AdamW
+
+            optimizer_cls = AdamW
+            optimizer_kwargs.update(adam_kwargs)
+        elif args.optim == OptimizerNames.ADAMW_TORCH:
+            from torch.optim import AdamW
+
+            optimizer_cls = AdamW
+            optimizer_kwargs.update(adam_kwargs)
+        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
+            try:
+                from apex.optimizers import FusedAdam
+
+                optimizer_cls = FusedAdam
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        else:
+            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+        return optimizer_cls, optimizer_kwargs
+
     def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
         """
         Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
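Because the selection logic is now a static method, it can be exercised without building a full Trainer. A minimal sketch under the same assumptions as above (placeholder output directory):

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="test_output", optim="adamw_torch")
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    # optimizer_cls is torch.optim.AdamW here, and optimizer_kwargs carries lr, betas,
    # and eps taken from the corresponding TrainingArguments fields.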

File: src/transformers/training_args.py

@@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional
 
 from .debug_utils import DebugOption
 from .file_utils import (
+    ExplicitEnum,
     cached_property,
     get_full_repo_name,
     is_sagemaker_dp_enabled,
@@ -69,6 +70,17 @@ def default_logdir() -> str:
     return os.path.join("runs", current_time + "_" + socket.gethostname())
 
 
+class OptimizerNames(ExplicitEnum):
+    """
+    Stores the acceptable string identifiers for optimizers.
+    """
+
+    ADAMW_HF = "adamw_hf"
+    ADAMW_TORCH = "adamw_torch"
+    ADAMW_APEX_FUSED = "adamw_apex_fused"
+    ADAFACTOR = "adafactor"
+
+
 @dataclass
 class TrainingArguments:
     """
@@ -327,8 +339,10 @@ class TrainingArguments:
            - `"tpu_metrics_debug"`: print debug metrics on TPU
 
            The options should be separated by whitespaces.
+        optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
+            The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
         adafactor (`bool`, *optional*, defaults to `False`):
-            Whether or not to use the [`Adafactor`] optimizer instead of [`AdamW`].
+            This argument is deprecated. Use `--optim adafactor` instead.
         group_by_length (`bool`, *optional*, defaults to `False`):
             Whether or not to group together samples of roughly the same length in the training dataset (to minimize
             padding applied and be more efficient). Only useful if applying dynamic padding.
@@ -641,6 +655,10 @@ class TrainingArguments:
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
+    optim: OptimizerNames = field(
+        default="adamw_hf",
+        metadata={"help": "The optimizer to use."},
+    )
     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
     group_by_length: bool = field(
         default=False,
@@ -809,6 +827,15 @@ class TrainingArguments:
             )
             if not (self.sharded_ddp == "" or not self.sharded_ddp):
                 raise ValueError("sharded_ddp is not supported with bf16")
+
+        self.optim = OptimizerNames(self.optim)
+
+        if self.adafactor:
+            warnings.warn(
+                "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim adafactor` instead",
+                FutureWarning,
+            )
+            self.optim = OptimizerNames.ADAFACTOR
         if (
             is_torch_available()
             and self.device.type != "cuda"
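Since OptimizerNames is an ExplicitEnum, the string passed on the command line is converted in __post_init__, and an unknown name should fail early with the list of valid identifiers. A small sketch:

    from transformers.training_args import OptimizerNames

    OptimizerNames("adafactor")      # -> OptimizerNames.ADAFACTOR
    OptimizerNames("adamw_torch")    # -> OptimizerNames.ADAMW_TORCH
    # OptimizerNames("sgd") raises a ValueError naming the supported choices,
    # so a typo in --optim surfaces before training starts.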

File: tests/test_trainer.py

@@ -23,10 +23,12 @@ import subprocess
 import tempfile
 import unittest
 from pathlib import Path
+from unittest.mock import Mock, patch
 
 import numpy as np
 from huggingface_hub import Repository, delete_repo, login
+from parameterized import parameterized
 from requests.exceptions import HTTPError
 
 from transformers import (
     AutoTokenizer,
@@ -36,7 +38,7 @@ from transformers import (
     is_torch_available,
     logging,
 )
-from transformers.file_utils import WEIGHTS_NAME
+from transformers.file_utils import WEIGHTS_NAME, is_apex_available
 from transformers.testing_utils import (
     ENDPOINT_STAGING,
     PASS,
@@ -61,6 +63,7 @@ from transformers.testing_utils import (
     slow,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.training_args import OptimizerNames
 from transformers.utils.hp_naming import TrialShortNamer
@@ -69,6 +72,7 @@ if is_torch_available():
     from torch import nn
     from torch.utils.data import IterableDataset
 
+    import transformers.optimization
     from transformers import (
         AutoModelForSequenceClassification,
         EarlyStoppingCallback,
@@ -1711,3 +1715,98 @@ class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
         trainer.hyperparameter_search(
             direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4
         )
+
+
+optim_test_params = []
+if is_torch_available():
+    default_adam_kwargs = {
+        "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
+        "eps": TrainingArguments.adam_epsilon,
+        "lr": TrainingArguments.learning_rate,
+    }
+
+    optim_test_params = [
+        (
+            OptimizerNames.ADAMW_HF,
+            transformers.optimization.AdamW,
+            default_adam_kwargs,
+        ),
+        (
+            OptimizerNames.ADAMW_HF.value,
+            transformers.optimization.AdamW,
+            default_adam_kwargs,
+        ),
+        (
+            OptimizerNames.ADAMW_TORCH,
+            torch.optim.AdamW,
+            default_adam_kwargs,
+        ),
+        (
+            OptimizerNames.ADAFACTOR,
+            transformers.optimization.Adafactor,
+            {
+                "scale_parameter": False,
+                "relative_step": False,
+                "lr": TrainingArguments.learning_rate,
+            },
+        ),
+    ]
+
+    if is_apex_available():
+        import apex
+
+        optim_test_params.append(
+            (
+                OptimizerNames.ADAMW_APEX_FUSED,
+                apex.optimizers.FusedAdam,
+                default_adam_kwargs,
+            )
+        )
+
+
+@require_torch
+class TrainerOptimizerChoiceTest(unittest.TestCase):
+    def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls):
+        args = TrainingArguments(optim=optim, output_dir="None")
+        actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
+        self.assertEqual(expected_cls, actual_cls)
+        self.assertIsNotNone(optim_kwargs)
+
+        for p, v in mandatory_kwargs.items():
+            self.assertTrue(p in optim_kwargs)
+            actual_v = optim_kwargs[p]
+            self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")
+
+    @parameterized.expand(optim_test_params, skip_on_empty=True)
+    def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):
+        # exercises all the valid --optim options
+        self.check_optim_and_kwargs(name, mandatory_kwargs, expected_cls)
+
+        trainer = get_regression_trainer(optim=name)
+        trainer.train()
+
+    def test_fused_adam(self):
+        # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
+        # Trainer.get_optimizer_cls_and_kwargs does not actually use FusedAdam; it only has to
+        # return the class, so mocking apex.optimizers.FusedAdam is fine for testing and allows
+        # the test to run without requiring an apex installation.
+        mock = Mock()
+        modules = {
+            "apex": mock,
+            "apex.optimizers": mock.optimizers,
+            "apex.optimizers.FusedAdam": mock.optimizers.FusedAdam,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                OptimizerNames.ADAMW_APEX_FUSED,
+                default_adam_kwargs,
+                mock.optimizers.FusedAdam,
+            )
+
+    def test_fused_adam_no_apex(self):
+        args = TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None")
+
+        # Pretend that apex does not exist, even if installed. By setting apex.optimizers to None,
+        # importing from apex will fail even if apex is installed.
+        with patch.dict("sys.modules", {"apex.optimizers": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
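A note on running the new tests locally, assuming they live in tests/test_trainer.py as reconstructed above; the apex cases are mocked, so no apex installation is needed:

    python -m pytest tests/test_trainer.py -k "TrainerOptimizerChoiceTest" -v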