mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> - Adds auto_batch_size finder - Moves training loop to an inner training loop
This commit is contained in:
parent
df735d1317
commit
2fbb237967
@ -84,6 +84,7 @@ jobs:
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
||||
- run: pip install tensorflow_probability
|
||||
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
|
||||
- run: pip install git+https://github.com/huggingface/accelerate
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -122,6 +123,7 @@ jobs:
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
||||
- run: pip install tensorflow_probability
|
||||
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
|
||||
- run: pip install git+https://github.com/huggingface/accelerate
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -154,6 +156,7 @@ jobs:
|
||||
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
||||
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
|
||||
- run: pip install git+https://github.com/huggingface/accelerate
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -191,6 +194,7 @@ jobs:
|
||||
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
||||
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
|
||||
- run: pip install git+https://github.com/huggingface/accelerate
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -222,6 +226,7 @@ jobs:
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
||||
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
|
||||
- run: pip install git+https://github.com/huggingface/accelerate
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -258,6 +263,7 @@ jobs:
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
|
||||
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
|
||||
- run: pip install git+https://github.com/huggingface/accelerate
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
|
1
setup.py
1
setup.py
@ -96,6 +96,7 @@ if stale_egg_info.exists():
|
||||
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
||||
_deps = [
|
||||
"Pillow",
|
||||
"accelerate>=0.7.1",
|
||||
"black~=22.0",
|
||||
"codecarbon==1.2.0",
|
||||
"cookiecutter==1.7.3",
|
||||
|
@ -3,6 +3,7 @@
|
||||
# 2. run `make deps_table_update``
|
||||
deps = {
|
||||
"Pillow": "Pillow",
|
||||
"accelerate": "accelerate>=0.7.1",
|
||||
"black": "black~=22.0",
|
||||
"codecarbon": "codecarbon==1.2.0",
|
||||
"cookiecutter": "cookiecutter==1.7.3",
|
||||
|
@ -40,6 +40,7 @@ from .integrations import (
|
||||
is_wandb_available,
|
||||
)
|
||||
from .utils import (
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_detectron2_available,
|
||||
@ -238,6 +239,13 @@ def require_git_lfs(test_case):
|
||||
return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_rjieba(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
|
||||
|
@ -115,6 +115,7 @@ from .trainer_utils import (
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
denumpify_detensorize,
|
||||
find_executable_batch_size,
|
||||
get_last_checkpoint,
|
||||
has_length,
|
||||
number_of_arguments,
|
||||
@ -548,6 +549,9 @@ class Trainer:
|
||||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||
|
||||
# Internal variables to keep track of the original batch size
|
||||
self._train_batch_size = args.train_batch_size
|
||||
|
||||
# very last
|
||||
self._memory_tracker.stop_and_update_metrics()
|
||||
|
||||
@ -718,7 +722,7 @@ class Trainer:
|
||||
if self.args.world_size > 1:
|
||||
train_dataset = IterableDatasetShard(
|
||||
train_dataset,
|
||||
batch_size=self.args.train_batch_size,
|
||||
batch_size=self._train_batch_size,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_processes=self.args.world_size,
|
||||
process_index=self.args.process_index,
|
||||
@ -736,7 +740,7 @@ class Trainer:
|
||||
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self.args.train_batch_size,
|
||||
batch_size=self._train_batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
@ -1267,6 +1271,20 @@ class Trainer:
|
||||
self._move_model_to_device(self.model, args.device)
|
||||
self.model_wrapped = self.model
|
||||
|
||||
inner_training_loop = find_executable_batch_size(
|
||||
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
|
||||
)
|
||||
return inner_training_loop(
|
||||
args=args,
|
||||
resume_from_checkpoint=resume_from_checkpoint,
|
||||
trial=trial,
|
||||
ignore_keys_for_eval=ignore_keys_for_eval,
|
||||
)
|
||||
|
||||
def _inner_training_loop(
|
||||
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
|
||||
):
|
||||
self._train_batch_size = batch_size
|
||||
# Data loader and number of training steps
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
|
||||
|
@ -36,6 +36,7 @@ from .utils import (
|
||||
is_torch_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_tpu_available,
|
||||
requires_backends,
|
||||
)
|
||||
|
||||
|
||||
@ -355,6 +356,7 @@ class TrainerMemoryTracker:
|
||||
stages = {
|
||||
"__init__": "init",
|
||||
"train": "train",
|
||||
"_inner_training_loop": "train",
|
||||
"evaluate": "eval",
|
||||
"predict": "test",
|
||||
}
|
||||
@ -584,6 +586,37 @@ class ShardedDDPOption(ExplicitEnum):
|
||||
AUTO_WRAP = "auto_wrap"
|
||||
|
||||
|
||||
def find_executable_batch_size(
|
||||
function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
|
||||
CUDNN, the batch size is cut in half and passed to `function` `function` must take in a `batch_size` parameter as
|
||||
its first argument.
|
||||
function (`callable`, *optional*)
|
||||
A function to wrap
|
||||
starting_batch_size (`int`, *optional*)
|
||||
The batch size to try and fit into memory
|
||||
auto_find_batch_size (`bool`, *optional*)
|
||||
If False, will just execute `function`
|
||||
"""
|
||||
if function is None:
|
||||
return functools.partial(
|
||||
find_executable_batch_size,
|
||||
starting_batch_size=starting_batch_size,
|
||||
auto_find_batch_size=auto_find_batch_size,
|
||||
)
|
||||
|
||||
if auto_find_batch_size:
|
||||
requires_backends(find_executable_batch_size, "accelerate")
|
||||
import accelerate.memory_utils as mem_utils
|
||||
|
||||
return mem_utils.find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
|
||||
|
||||
return functools.partial(function, batch_size=starting_batch_size)
|
||||
|
||||
|
||||
class FSDPOption(ExplicitEnum):
|
||||
FULL_SHARD = "full_shard"
|
||||
SHARD_GRAD_OP = "shard_grad_op"
|
||||
|
@ -443,6 +443,9 @@ class TrainingArguments:
|
||||
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
|
||||
that need inputs, predictions and references for scoring calculation in Metric class.
|
||||
auto_find_batch_size (`bool`, *optional*, defaults to `False`)
|
||||
Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
|
||||
CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
|
||||
"""
|
||||
|
||||
output_dir: str = field(
|
||||
@ -803,6 +806,13 @@ class TrainingArguments:
|
||||
metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
|
||||
)
|
||||
|
||||
auto_find_batch_size: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to automatically decrease the batch size in half and rerun the training loop again each time a CUDA Out-of-Memory was reached"
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
|
||||
# This needs to happen before any call to self.device or self.n_gpu.
|
||||
|
@ -85,6 +85,7 @@ from .import_utils import (
|
||||
DummyObject,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_coloredlogs_available,
|
||||
|
@ -428,6 +428,10 @@ def is_protobuf_available():
|
||||
return importlib.util.find_spec("google.protobuf") is not None
|
||||
|
||||
|
||||
def is_accelerate_available():
|
||||
return importlib.util.find_spec("accelerate") is not None
|
||||
|
||||
|
||||
def is_tokenizers_available():
|
||||
return importlib.util.find_spec("tokenizers") is not None
|
||||
|
||||
@ -725,6 +729,12 @@ PYCTCDECODE_IMPORT_ERROR = """
|
||||
`pip install pyctcdecode`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
ACCELERATE_IMPORT_ERROR = """
|
||||
{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
|
||||
`pip install accelerate`
|
||||
"""
|
||||
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
@ -750,6 +760,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
||||
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
|
||||
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
||||
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -21,6 +21,7 @@ import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
@ -58,6 +59,7 @@ from transformers.testing_utils import (
|
||||
require_torch_bf16,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
require_wandb,
|
||||
@ -1075,6 +1077,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertAlmostEqual(a, a1, delta=1e-8)
|
||||
self.assertAlmostEqual(b, b1, delta=1e-8)
|
||||
|
||||
@slow
|
||||
@require_torch_non_multi_gpu
|
||||
def test_auto_batch_size_finder(self):
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
SRC_DIR = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
|
||||
)
|
||||
sys.path.append(SRC_DIR)
|
||||
import run_glue
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
testargs = f"""
|
||||
run_glue.py
|
||||
--model_name_or_path distilbert-base-uncased
|
||||
--task_name mrpc
|
||||
--do_train
|
||||
--do_eval
|
||||
--max_seq_len 128
|
||||
--per_device_train_batch_size 4096
|
||||
--learning_rate 2e-5
|
||||
--num_train_epochs 1
|
||||
--output_dir {tmpdir}
|
||||
--auto_find_batch_size 0
|
||||
""".split()
|
||||
with self.assertRaises(RuntimeError):
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_glue.main()
|
||||
|
||||
testargs[-1] = "1"
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_glue.main()
|
||||
|
||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||
def test_training_with_resume_from_checkpoint_false(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
|
@ -18,7 +18,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.testing_utils import require_accelerate, require_torch
|
||||
from transformers.trainer_utils import find_executable_batch_size
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
@ -420,3 +421,39 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
|
||||
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
|
||||
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
|
||||
|
||||
@require_accelerate
|
||||
def test_executable_batch_size(self):
|
||||
batch_sizes = []
|
||||
|
||||
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
|
||||
def mock_training_loop_function(batch_size):
|
||||
nonlocal batch_sizes
|
||||
batch_sizes.append(batch_size)
|
||||
if batch_size > 16:
|
||||
raise RuntimeError("CUDA out of memory.")
|
||||
|
||||
mock_training_loop_function()
|
||||
self.assertEqual(batch_sizes, [64, 32, 16])
|
||||
|
||||
@require_accelerate
|
||||
def test_executable_batch_size_no_search(self):
|
||||
batch_sizes = []
|
||||
|
||||
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
|
||||
def mock_training_loop_function(batch_size):
|
||||
nonlocal batch_sizes
|
||||
batch_sizes.append(batch_size)
|
||||
|
||||
mock_training_loop_function()
|
||||
self.assertEqual(batch_sizes, [64])
|
||||
|
||||
@require_accelerate
|
||||
def test_executable_batch_size_with_error(self):
|
||||
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
|
||||
def mock_training_loop_function(batch_size):
|
||||
raise RuntimeError("CUDA out of memory.")
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
mock_training_loop_function()
|
||||
self.assertEqual("CUDA out of memory", cm.args[0])
|
||||
|
Loading…
Reference in New Issue
Block a user