Add AdEMAMix optimizer (#33682)

* Add AdEMAMix optimizer

* Fix test

* Update tests/trainer/test_trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Matthew Douglas 2024-09-25 13:07:21 -04:00 committed by GitHub
parent 61e98cb957
commit 196d35ccfc
3 changed files with 201 additions and 1 deletion

src/transformers/trainer.py

@@ -1237,6 +1237,10 @@ class Trainer:
            OptimizerNames.ADAMW_8BIT,
            OptimizerNames.PAGED_ADAMW,
            OptimizerNames.PAGED_ADAMW_8BIT,
            OptimizerNames.ADEMAMIX,
            OptimizerNames.ADEMAMIX_8BIT,
            OptimizerNames.PAGED_ADEMAMIX,
            OptimizerNames.PAGED_ADEMAMIX_8BIT,
            OptimizerNames.LION,
            OptimizerNames.LION_8BIT,
            OptimizerNames.PAGED_LION,
@@ -1266,6 +1270,33 @@ class Trainer:
                    # Above we pass all `adam_kwargs` to the optimizer, here
                    # we only pass `optim_args` which can be passed by the user.
                    additional_optim_kwargs = optim_args
                elif "ademamix" in args.optim:
                    if is_bitsandbytes_available() and version.parse(
                        importlib.metadata.version("bitsandbytes")
                    ) < version.parse("0.44.0"):
                        raise ValueError(
                            "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
                            "Please install `bitsandbytes` >= 0.44.0."
                        )

                    from bitsandbytes.optim import AdEMAMix

                    optimizer_cls = AdEMAMix
                    additional_optim_kwargs = {
                        "betas": (
                            float(optim_args.get("beta1", args.adam_beta1)),
                            float(optim_args.get("beta2", args.adam_beta2)),
                            float(optim_args.get("beta3", 0.9999)),
                        ),
                        "alpha": float(optim_args.get("alpha", 5.0)),
                        "eps": float(optim_args.get("eps", args.adam_epsilon)),
                    }

                    if "t_alpha" in optim_args:
                        additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])

                    if "t_beta3" in optim_args:
                        additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])

                bnb_kwargs = {"optim_bits": optim_bits}
                if "rmsprop" not in args.optim:

src/transformers/training_args.py

@@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum):
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
    ADEMAMIX = "ademamix"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
    ADEMAMIX_8BIT = "ademamix_8bit"
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_ADEMAMIX = "paged_ademamix_32bit"
    PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"
    RMSPROP = "rmsprop"
@@ -618,7 +622,7 @@ class TrainingArguments:
            "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
            for a full list of optimizers.
        optim_args (`str`, *optional*):
            Optional arguments that are supplied to AnyPrecisionAdamW.
            Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
        group_by_length (`bool`, *optional*, defaults to `False`):
            Whether or not to group together samples of roughly the same length in the training dataset (to minimize
            padding applied and be more efficient). Only useful if applying dynamic padding.
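
Since `optim_args` is where the AdEMAMix-specific hyperparameters surface, here is a rough, illustrative sketch of the update they feed into, following the AdEMAMix paper: a fast gradient EMA controlled by `beta1`, a slow EMA controlled by `beta3` and mixed in with weight `alpha`, and an Adam-style second moment controlled by `beta2`. The `t_alpha`/`t_beta3` options, omitted here, warm `alpha` and `beta3` up to their target values over that many steps; weight decay is also omitted. This is only a sketch of the idea, not the bitsandbytes implementation.

import torch


def ademamix_step(p, grad, state, lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0, eps=1e-8):
    """One illustrative AdEMAMix update applied to parameter tensor `p` in place."""
    beta1, beta2, beta3 = betas
    state["step"] += 1
    t = state["step"]

    state["m1"].mul_(beta1).add_(grad, alpha=1 - beta1)           # fast EMA (Adam-style first moment)
    state["m2"].mul_(beta3).add_(grad, alpha=1 - beta3)           # slow EMA, mixed in with weight `alpha`
    state["v"].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second-moment estimate

    m1_hat = state["m1"] / (1 - beta1**t)  # bias correction for the fast EMA only
    v_hat = state["v"] / (1 - beta2**t)
    p.sub_(lr * (m1_hat + alpha * state["m2"]) / (v_hat.sqrt() + eps))


p = torch.zeros(4)
state = {"step": 0, "m1": torch.zeros_like(p), "m2": torch.zeros_like(p), "v": torch.zeros_like(p)}
ademamix_step(p, grad=torch.randn(4), state=state)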

tests/trainer/test_trainer.py

@@ -15,6 +15,7 @@
import dataclasses
import gc
import importlib
import json
import math
import os
@@ -32,6 +33,7 @@ from unittest.mock import Mock, patch
import numpy as np
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
from packaging import version
from parameterized import parameterized
from requests.exceptions import HTTPError
@@ -1091,6 +1093,40 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            # Check that it trains without errors
            trainer.train()

    @require_bitsandbytes
    def test_ademamix_bnb(self):
        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
        tiny_gpt2 = GPT2LMHeadModel(config)

        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix"
            )
            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)

            # Check that it trains without errors
            trainer.train()

    @require_bitsandbytes
    def test_ademamix_bnb_8bit(self):
        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
        tiny_gpt2 = GPT2LMHeadModel(config)

        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit"
            )
            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)

            # Check that it trains without errors
            trainer.train()

    @require_bitsandbytes
    def test_rmsprop_bnb_8bit(self):
        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
@@ -4187,6 +4223,13 @@ if is_torch_available():
        "lr": TrainingArguments.learning_rate,
    }

    default_ademamix_kwargs = {
        "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999),
        "alpha": 5.0,
        "eps": TrainingArguments.adam_epsilon,
        "lr": TrainingArguments.learning_rate,
    }

    default_anyprecision_kwargs = {
        "use_kahan_summation": False,
        "momentum_dtype": torch.float32,
@@ -4291,6 +4334,36 @@ if is_torch_available():
            )
        )

        if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"):
            optim_test_params.append(
                (
                    TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
                    bnb.optim.AdEMAMix,
                    default_ademamix_kwargs,
                )
            )
            optim_test_params.append(
                (
                    TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
                    bnb.optim.AdEMAMix,
                    default_ademamix_kwargs,
                )
            )
            optim_test_params.append(
                (
                    TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
                    bnb.optim.AdEMAMix,
                    default_ademamix_kwargs,
                )
            )
            optim_test_params.append(
                (
                    TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
                    bnb.optim.AdEMAMix,
                    default_ademamix_kwargs,
                )
            )

    if is_torchdistx_available():
        import torchdistx
@@ -4420,6 +4493,62 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
                default_adam_kwargs,
            )

    def test_bnb_ademamix(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
                mock.optim.AdEMAMix,
                default_ademamix_kwargs,
            )

    def test_bnb_ademamix8bit(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
                mock.optim.AdEMAMix,
                default_ademamix_kwargs,
            )

    def test_bnb_paged_ademamix(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
                mock.optim.AdEMAMix,
                default_ademamix_kwargs,
            )

    def test_bnb_paged_ademamix8bit(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
                mock.optim.AdEMAMix,
                default_ademamix_kwargs,
            )

    def test_bnb_lion(self):
        mock = Mock()
        modules = {
@@ -4503,6 +4632,42 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
        with self.assertRaises(ValueError):
            Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_ademamix_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")

        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if `bitsandbytes` is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_ademamix8bit_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")

        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if `bitsandbytes` is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_paged_ademamix_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")

        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if `bitsandbytes` is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_paged_ademamix8bit_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")

        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if `bitsandbytes` is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_paged_lion_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")