Add support for bitsandbytes (#15622)

* Add initial BNB integration

* fixup! Add initial BNB integration

* Add bnb test decorator

* Update Adamw8bit option name

* Use the full bnb package name

* Override bnb for all embedding layers

* Fix package name

* Formatting

* Remove unnecessary import

* Update src/transformers/trainer.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename AdamwBNB optimizer option

* Add training test checking that bnb memory utilization is lower

* fix merge

* fix merge; fix + extend new test

* cleanup

* expand bnb

* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
Manuel R. Ciosici 2022-04-19 13:01:29 -07:00 committed by GitHub
parent e6d23a4b9b
commit 3104036e7f
7 changed files with 194 additions and 29 deletions
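In practical terms, the commit exposes bitsandbytes' 8-bit Adam through the existing optim training argument. A minimal sketch, not part of the diff, of how a user would opt in once this change lands (the output directory and hyperparameters below are placeholders):

# Select the new 8-bit Adam optimizer via the `optim` argument added in this commit;
# bitsandbytes must be installed separately for this to work at runtime.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",        # placeholder
    optim="adamw_bnb_8bit",  # string value of the new OptimizerNames.ADAMW_BNB
    learning_rate=3e-4,      # placeholder hyperparameter
)
# A Trainer built with these args instantiates bitsandbytes.optim.Adam8bit and, as the
# trainer.py hunk further down shows, keeps optimizer state for nn.Embedding weights in fp32.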


@@ -31,8 +31,16 @@ from unittest import mock
from transformers import logging as transformers_logging
from .deepspeed import is_deepspeed_available
from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
from .integrations import (
is_fairscale_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,
is_wandb_available,
)
from .utils import (
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
return test_case
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case
def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bitsandbytes (bnb)
"""
if not is_bitsandbytes_available():
return unittest.skip("test requires bnb")(test_case)
else:
return test_case
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer

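For context, a brief sketch, not part of the diff, of how the new require_bitsandbytes decorator is meant to be used together with the existing require_torch_gpu decorator; the test class and its body are hypothetical:

import unittest

from transformers.testing_utils import require_bitsandbytes, require_torch_gpu


@require_bitsandbytes
@require_torch_gpu
class Adam8bitImportTest(unittest.TestCase):
    # hypothetical smoke test: skipped automatically when bitsandbytes is not installed
    def test_adam8bit_is_importable(self):
        from bitsandbytes.optim import Adam8bit

        self.assertTrue(callable(Adam8bit))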

@@ -867,6 +867,15 @@ class Trainer:
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
for module in self.model.modules():
if isinstance(module, nn.Embedding):
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
@@ -917,6 +926,14 @@ optimizer_kwargs.update(adam_kwargs)
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
elif args.optim == OptimizerNames.ADAMW_BNB:
try:
from bitsandbytes.optim import Adam8bit
optimizer_cls = Adam8bit
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
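Outside of Trainer, the same pattern can be reproduced directly against the bitsandbytes API. A minimal sketch, not part of the commit; the model name and learning rate are borrowed from the tests further down as placeholders:

import bitsandbytes as bnb
import torch.nn as nn

from transformers import AutoModel

model = AutoModel.from_pretrained("sshleifer/student_marian_en_ro_6_1")

# 8-bit Adam for all parameters ...
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=3e-4)

# ... except embedding weights, whose optimizer state is kept in 32-bit,
# mirroring the override the Trainer registers above
manager = bnb.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, nn.Embedding):
        manager.register_module_override(module, "weight", {"optim_bits": 32})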


@@ -79,6 +79,7 @@ class OptimizerNames(ExplicitEnum):
ADAMW_TORCH_XLA = "adamw_torch_xla"
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_BNB = "adamw_bnb_8bit"
@dataclass


@@ -85,6 +85,7 @@ from .import_utils import (
DummyObject,
_LazyModule,
is_apex_available,
is_bitsandbytes_available,
is_coloredlogs_available,
is_datasets_available,
is_detectron2_available,


@@ -400,6 +400,10 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None
def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None
def is_faiss_available():
return _faiss_available
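As with the other availability helpers, bnb-specific code paths are expected to be guarded on this check. A short, hypothetical sketch, not part of the commit:

from transformers.utils import is_bitsandbytes_available

if is_bitsandbytes_available():
    from bitsandbytes.optim import Adam8bit as OptimizerCls
else:
    # fall back to plain torch AdamW when bitsandbytes is not installed
    from torch.optim import AdamW as OptimizerCls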


@@ -17,10 +17,11 @@ import os
import re
import sys
import unittest
from typing import Tuple
from unittest.mock import patch
from parameterized import parameterized
from transformers.integrations import is_fairscale_available
from transformers import AutoModel
from transformers.testing_utils import (
CaptureStderr,
ExtendSysPath,
@@ -28,6 +29,9 @@ from transformers.testing_utils import (
execute_subprocess_async,
get_gpu_count,
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_fairscale,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
@@ -36,7 +40,6 @@ from transformers.testing_utils import (
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from transformers.utils import is_apex_available
bindir = os.path.abspath(os.path.dirname(__file__))
@@ -49,28 +52,6 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"
# a candidate for testing_utils
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case
# a candidate for testing_utils
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case
@require_torch
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(
@@ -193,7 +174,7 @@ class TestTrainerExt(TestCasePlus):
self.assertEqual(n_matches, data["n_matches"])
@slow
def test_run_seq2seq_slow(self):
def test_run_seq2seq(self):
output_dir = self.run_trainer(
eval_steps=2,
max_len=128,
@@ -218,6 +199,88 @@ class TestTrainerExt(TestCasePlus):
assert "generated_predictions.txt" in contents
assert "predict_results.json" in contents
@slow
@require_bitsandbytes
def test_run_seq2seq_bnb(self):
from transformers.training_args import OptimizerNames
def train_and_return_metrics(optim: str) -> Tuple[int, int, float]:
from pathlib import Path
extra_args = (
f"--skip_memory_metrics 0 --optim {optim} --do_eval False --do_predict "
"False --adafactor False --log_level debug"
)
output_dir = self.run_trainer(
eval_steps=2,
max_len=128,
model_name=MARIAN_MODEL,
learning_rate=3e-4,
num_train_epochs=1,
distributed=True, # force run in a new process
extra_args_str=extra_args,
do_eval=False,
do_predict=False,
)
# Check metrics
logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]
loss = logs[0]["train_loss"]
return gpu_peak_mem, gpu_alloc_mem, loss
gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)
gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb
gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb
gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb
# leave this for now if CI gets very different results
# print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}" )
# print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
# print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
# print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
# print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")
self.assertGreater(
gpu_peak_mem_diff_percent,
10, # basically a huge difference - got ~30x on my desktop
"should use very little peak gpu memory with BNB, compared to without it"
f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
)
self.assertGreater(
gpu_total_mem_diff_percent,
0.20, # could easily be 0.50, but let's stay on the safe side
"Using BNB should use less total GPU memory than without it"
f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
)
self.assertEqual(
loss_orig, loss_bnb, "loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
)
# Additionally let's test that the absolute gpu memory difference is larger or about the
# same as the expected saving coming from BNB (6 bytes per param)
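# (fp32 Adam keeps two 4-byte state tensors per parameter, exp_avg and exp_avg_sq,
# while Adam8bit stores the same two states in 1 byte each, hence roughly
# 8 - 2 = 6 bytes saved per parameter, ignoring bnb's small per-block constants)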
model = AutoModel.from_pretrained(MARIAN_MODEL)
total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
bnb_saved_bytes = total_numel * 6 # 324MB
self.assertGreater(
gpu_total_mem_diff_bytes,
bnb_saved_bytes * 0.8, # add a safety margin, if it saved slightly less
f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
)
def run_trainer(
self,
eval_steps: int,
@@ -300,6 +363,8 @@ class TestTrainerExt(TestCasePlus):
{self.examples_dir_str}/pytorch/translation/run_translation.py
""".split()
cmd = [sys.executable] + distributed_args + args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
execute_subprocess_async(cmd, env=self.get_env())
else:
testargs = ["run_translation.py"] + args


@@ -65,7 +65,7 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import WEIGHTS_NAME, is_apex_available
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils.hp_naming import TrialShortNamer
@@ -1870,6 +1870,7 @@ if is_torch_available():
},
),
]
if is_apex_available():
import apex
@@ -1881,6 +1882,17 @@ )
)
)
if is_bitsandbytes_available():
import bitsandbytes as bnb
optim_test_params.append(
(
OptimizerNames.ADAMW_BNB,
bnb.optim.Adam8bit,
default_adam_kwargs,
)
)
@require_torch
class TrainerOptimizerChoiceTest(unittest.TestCase):
@@ -1905,8 +1917,8 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
def test_fused_adam(self):
# Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
# class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
# class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# the test to run without requiring an apex installation.
mock = Mock()
modules = {
@@ -1930,6 +1942,33 @@ with self.assertRaises(ValueError):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
def test_bnb_adam8bit(self):
# Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
# Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
# class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
# the test to run without requiring a bnb installation.
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
OptimizerNames.ADAMW_BNB,
default_adam_kwargs,
mock.optim.Adam8bit,
)
def test_bnb_adam8bit_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
# Pretend that bnb does not exist, even if installed. Setting bitsandbytes.optim to None in
# sys.modules makes the import of Adam8bit fail, even if bnb is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
@require_torch
@require_wandb