Add support for bitsandbytes (#15622)
* Add initial BNB integration
* fixup! Add initial BNB integration
* Add bnb test decorator
* Update Adamw8bit option name
* Use the full bnb package name
* Override bnb for all embedding layers
* Fix package name
* Formatting
* Remove unnecessary import
* Update src/transformers/trainer.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename AdamwBNB optimizer option
* Add training test checking that bnb memory utilization is lower
* fix merge
* fix merge; fix + extend new test
* cleanup
* expand bnb
* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
parent e6d23a4b9b
commit 3104036e7f
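Before the diff itself: a minimal usage sketch of the option this commit adds. It is not part of the commit; the output_dir value is a placeholder and the model/dataset wiring is omitted.

from transformers import Trainer, TrainingArguments

# "adamw_bnb_8bit" is the new OptimizerNames.ADAMW_BNB value added below; Trainer resolves it
# to bitsandbytes.optim.Adam8bit and raises a ValueError if bitsandbytes is not installed.
args = TrainingArguments(output_dir="out", optim="adamw_bnb_8bit")
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # model/dataset omitted
# trainer.train()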
@@ -31,8 +31,16 @@ from unittest import mock
from transformers import logging as transformers_logging

from .deepspeed import is_deepspeed_available
from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
from .integrations import (
    is_fairscale_available,
    is_optuna_available,
    is_ray_available,
    is_sigopt_available,
    is_wandb_available,
)
from .utils import (
    is_apex_available,
    is_bitsandbytes_available,
    is_detectron2_available,
    is_faiss_available,
    is_flax_available,
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
    return test_case


def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    else:
        return test_case


def require_bitsandbytes(test_case):
    """
    Decorator for bits and bytes (bnb) dependency
    """
    if not is_bitsandbytes_available():
        return unittest.skip("test requires bnb")(test_case)
    else:
        return test_case


def require_phonemizer(test_case):
    """
    Decorator marking a test that requires phonemizer
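For illustration only (not part of the diff): the new require_bitsandbytes decorator is applied like the other require_* helpers; the test class below is hypothetical.

import unittest

from transformers.testing_utils import require_bitsandbytes


@require_bitsandbytes
class HypotheticalBnbTest(unittest.TestCase):  # hypothetical name, for illustration
    def test_placeholder(self):
        # Runs only when bitsandbytes is importable; otherwise unittest marks it as skipped.
        self.assertTrue(True)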
@@ -867,6 +867,15 @@ class Trainer:
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    for module in self.model.modules():
                        if isinstance(module, nn.Embedding):
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)
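For context, a standalone sketch of the pattern the Trainer applies above: 8-bit Adam for everything except embedding weights, which are kept in 32-bit optimizer state. The toy model is made up; the bitsandbytes calls mirror the ones in the diff.

import torch.nn as nn

import bitsandbytes as bnb

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 2))  # toy model, illustration only
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=3e-4)

# Register a per-module override before the first optimizer step, as the Trainer does for
# every nn.Embedding it finds: the embedding's optimizer state stays in 32 bits for stability.
manager = bnb.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, nn.Embedding):
        manager.register_module_override(module, "weight", {"optim_bits": 32})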
@@ -917,6 +926,14 @@ class Trainer:
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
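What the new branch resolves to, written out as a plain sketch; it assumes bitsandbytes is installed, the model is a stand-in, and the hyperparameter values are examples (in the Trainer they come from optimizer_grouped_parameters and adam_kwargs).

import torch.nn as nn

from bitsandbytes.optim import Adam8bit

model = nn.Linear(10, 10)  # stand-in model, illustration only

# Drop-in for torch.optim.AdamW with the same hyperparameters; the two Adam moment tensors
# are stored in 8 bits, which is where the memory saving comes from.
optimizer = Adam8bit(model.parameters(), lr=3e-4, betas=(0.9, 0.999), eps=1e-8)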
@@ -79,6 +79,7 @@ class OptimizerNames(ExplicitEnum):
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_BNB = "adamw_bnb_8bit"


@dataclass
@@ -85,6 +85,7 @@ from .import_utils import (
    DummyObject,
    _LazyModule,
    is_apex_available,
    is_bitsandbytes_available,
    is_coloredlogs_available,
    is_datasets_available,
    is_detectron2_available,
@@ -400,6 +400,10 @@ def is_apex_available():
    return importlib.util.find_spec("apex") is not None


def is_bitsandbytes_available():
    return importlib.util.find_spec("bitsandbytes") is not None


def is_faiss_available():
    return _faiss_available
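A small usage sketch (not from the diff) of the new availability helper, following the same guarded-import pattern the test files below use:

from transformers.utils import is_bitsandbytes_available

if is_bitsandbytes_available():
    import bitsandbytes as bnb  # safe: find_spec has already confirmed the package is present
else:
    bnb = None  # callers can branch on this or raise a more helpful error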
@@ -17,10 +17,11 @@ import os
import re
import sys
import unittest
from typing import Tuple
from unittest.mock import patch

from parameterized import parameterized
from transformers.integrations import is_fairscale_available
from transformers import AutoModel
from transformers.testing_utils import (
    CaptureStderr,
    ExtendSysPath,
@@ -28,6 +29,9 @@ from transformers.testing_utils import (
    execute_subprocess_async,
    get_gpu_count,
    get_torch_dist_unique_port,
    require_apex,
    require_bitsandbytes,
    require_fairscale,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
@@ -36,7 +40,6 @@ from transformers.testing_utils import (
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from transformers.utils import is_apex_available


bindir = os.path.abspath(os.path.dirname(__file__))
@@ -49,28 +52,6 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    else:
        return test_case


@require_torch
class TestTrainerExt(TestCasePlus):
    def run_seq2seq_quick(
@@ -193,7 +174,7 @@ class TestTrainerExt(TestCasePlus):
        self.assertEqual(n_matches, data["n_matches"])

    @slow
    def test_run_seq2seq_slow(self):
    def test_run_seq2seq(self):
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
@@ -218,6 +199,88 @@ class TestTrainerExt(TestCasePlus):
        assert "generated_predictions.txt" in contents
        assert "predict_results.json" in contents

    @slow
    @require_bitsandbytes
    def test_run_seq2seq_bnb(self):
        from transformers.training_args import OptimizerNames

        def train_and_return_metrics(optim: str) -> Tuple[int, float]:
            from pathlib import Path

            extra_args = (
                f"--skip_memory_metrics 0 --optim {optim} --do_eval False --do_predict "
                "False --adafactor False --log_level debug"
            )

            output_dir = self.run_trainer(
                eval_steps=2,
                max_len=128,
                model_name=MARIAN_MODEL,
                learning_rate=3e-4,
                num_train_epochs=1,
                distributed=True, # force run in a new process
                extra_args_str=extra_args,
                do_eval=False,
                do_predict=False,
            )

            # Check metrics
            logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
            gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
            gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]

            loss = logs[0]["train_loss"]
            return gpu_peak_mem, gpu_alloc_mem, loss

        gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
        gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)

        gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
        gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb

        gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
        gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb

        gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
        gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb

        # leave this for now if CI gets very different results
        # print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}" )
        # print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
        # print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
        # print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
        # print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")

        self.assertGreater(
            gpu_peak_mem_diff_percent,
            10, # basically a huge difference - got ~30x on my desktop
            "should use very little peak gpu memory with BNB, compared to without it"
            f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
        )

        self.assertGreater(
            gpu_total_mem_diff_percent,
            0.20, # could easily be 0.50, but let's stay on the safe side
            "Using BNB should use less total GPU memory than without it"
            f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
        )

        self.assertEqual(
            loss_orig, loss_bnb, "loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
        )

        # Additionally let's test that the absolute gpu memory difference is larger or about the
        # same as the expected saving coming from BNB (6 bytes per param)
        model = AutoModel.from_pretrained(MARIAN_MODEL)
        total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
        bnb_saved_bytes = total_numel * 6 # 324MB

        self.assertGreater(
            gpu_total_mem_diff_bytes,
            bnb_saved_bytes * 0.8, # add a safety margin, if it saved slightly less
            f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
        )
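A rough back-of-the-envelope behind the "6 bytes per param" figure used in the test above; the per-parameter state sizes are assumptions about how AdamW and Adam8bit store their moments, not something the diff itself asserts.

# AdamW keeps two fp32 state tensors per parameter (exp_avg and exp_avg_sq).
adamw_state_bytes_per_param = 2 * 4  # 8 bytes
# bnb's Adam8bit keeps the same two moments quantized to 8 bits (plus small per-block constants).
adam8bit_state_bytes_per_param = 2 * 1  # ~2 bytes
expected_saving_per_param = adamw_state_bytes_per_param - adam8bit_state_bytes_per_param  # 6 bytes
# For the Marian student model used in the test, total_numel * 6 comes out to roughly the
# 324MB mentioned in the bnb_saved_bytes comment.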
    def run_trainer(
        self,
        eval_steps: int,
@@ -300,6 +363,8 @@ class TestTrainerExt(TestCasePlus):
                {self.examples_dir_str}/pytorch/translation/run_translation.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            # keep for quick debug
            # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            testargs = ["run_translation.py"] + args
@@ -65,7 +65,7 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import WEIGHTS_NAME, is_apex_available
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils.hp_naming import TrialShortNamer
@@ -1870,6 +1870,7 @@ if is_torch_available():
            },
        ),
    ]

    if is_apex_available():
        import apex

@@ -1881,6 +1882,17 @@ if is_torch_available():
            )
        )

    if is_bitsandbytes_available():
        import bitsandbytes as bnb

        optim_test_params.append(
            (
                OptimizerNames.ADAMW_BNB,
                bnb.optim.Adam8bit,
                default_adam_kwargs,
            )
        )


@require_torch
class TrainerOptimizerChoiceTest(unittest.TestCase):
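For context, a standalone version of what the new parametrization and mock tests exercise, calling Trainer.get_optimizer_cls_and_kwargs directly. This is a sketch assuming bitsandbytes is installed; output_dir is a placeholder.

import bitsandbytes as bnb

from transformers import Trainer, TrainingArguments
from transformers.training_args import OptimizerNames

args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="tmp_out")  # output_dir is a placeholder
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
assert optimizer_cls is bnb.optim.Adam8bit
print(optimizer_kwargs)  # the Adam kwargs (betas, eps, lr, ...) that will be passed to the optimizer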
@@ -1905,8 +1917,8 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):

    def test_fused_adam(self):
        # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
        # class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
        # class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
        # the test to run without requiring an apex installation.
        mock = Mock()
        modules = {
@@ -1930,6 +1942,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
        with self.assertRaises(ValueError):
            Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_adam8bit(self):
        # Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
        # Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
        # class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
        # the test to run without requiring a bnb installation.
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                OptimizerNames.ADAMW_BNB,
                default_adam_kwargs,
                mock.optim.Adam8bit,
            )

    def test_bnb_adam8bit_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")

        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if bnb is installed.
        with patch.dict("sys.modules", {"bnb.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)


@require_torch
@require_wandb