Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-19 04:28:26 +06:00)
Add the auto_find_batch_size capability from Accelerate into Trainer (#17068)

- Adds auto_batch_size finder
- Moves training loop to an inner training loop

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in: commit 2fbb237967 (parent df735d1317)
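For context, a minimal end-to-end sketch of what the new flag enables (illustrative only, not part of this commit): the toy model and dataset below are placeholders, and it assumes torch, a transformers build containing this change, and a contemporary accelerate release (one that still exposes accelerate.memory_utils) are installed. With auto_find_batch_size=True, a CUDA out-of-memory error during training makes the Trainer retry its inner loop with the batch size halved.

import torch
from torch import nn
from transformers import Trainer, TrainingArguments


class ToyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 256

    def __getitem__(self, idx):
        x = torch.randn(4)
        return {"input_values": x, "labels": x.sum().unsqueeze(0)}


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, input_values=None, labels=None):
        preds = self.linear(input_values)
        return {"loss": nn.functional.mse_loss(preds, labels), "logits": preds}


if __name__ == "__main__":
    args = TrainingArguments(
        output_dir="toy_out",
        per_device_train_batch_size=4096,  # deliberately oversized
        auto_find_batch_size=True,  # the new flag; needs `pip install accelerate`
        num_train_epochs=1,
    )
    Trainer(model=ToyModel(), args=args, train_dataset=ToyDataset()).train()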
@@ -84,6 +84,7 @@ jobs:
         - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
         - run: pip install tensorflow_probability
         - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+        - run: pip install git+https://github.com/huggingface/accelerate
         - save_cache:
             key: v0.4-{{ checksum "setup.py" }}
             paths:
@@ -122,6 +123,7 @@ jobs:
         - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
         - run: pip install tensorflow_probability
         - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+        - run: pip install git+https://github.com/huggingface/accelerate
         - save_cache:
             key: v0.4-{{ checksum "setup.py" }}
             paths:
@@ -154,6 +156,7 @@ jobs:
         - run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
         - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
         - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+        - run: pip install git+https://github.com/huggingface/accelerate
         - save_cache:
             key: v0.4-{{ checksum "setup.py" }}
             paths:
@@ -191,6 +194,7 @@ jobs:
         - run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
         - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
         - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+        - run: pip install git+https://github.com/huggingface/accelerate
         - save_cache:
             key: v0.4-{{ checksum "setup.py" }}
             paths:
@@ -222,6 +226,7 @@ jobs:
         - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
         - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
         - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+        - run: pip install git+https://github.com/huggingface/accelerate
         - save_cache:
             key: v0.4-torch-{{ checksum "setup.py" }}
             paths:
@@ -258,6 +263,7 @@ jobs:
         - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
         - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
         - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+        - run: pip install git+https://github.com/huggingface/accelerate
         - save_cache:
             key: v0.4-torch-{{ checksum "setup.py" }}
             paths:
setup.py
@@ -96,6 +96,7 @@ if stale_egg_info.exists():
 # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
 _deps = [
     "Pillow",
+    "accelerate>=0.7.1",
     "black~=22.0",
     "codecarbon==1.2.0",
     "cookiecutter==1.7.3",
src/transformers/dependency_versions_table.py
@@ -3,6 +3,7 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow",
+    "accelerate": "accelerate>=0.7.1",
     "black": "black~=22.0",
     "codecarbon": "codecarbon==1.2.0",
     "cookiecutter": "cookiecutter==1.7.3",
@@ -40,6 +40,7 @@ from .integrations import (
     is_wandb_available,
 )
 from .utils import (
+    is_accelerate_available,
     is_apex_available,
     is_bitsandbytes_available,
     is_detectron2_available,
@@ -238,6 +239,13 @@ def require_git_lfs(test_case):
     return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)


+def require_accelerate(test_case):
+    """
+    Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+    """
+    return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
 def require_rjieba(test_case):
     """
     Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
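A short sketch of how the new decorator would be used in a test module (the class and test names below are made up for illustration):

import unittest

from transformers.testing_utils import require_accelerate


class AccelerateOnlyTest(unittest.TestCase):
    @require_accelerate
    def test_runs_only_when_accelerate_is_installed(self):
        import accelerate  # safe: the decorator skips this test when accelerate is absent

        self.assertTrue(hasattr(accelerate, "__version__"))


if __name__ == "__main__":
    unittest.main()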
@@ -115,6 +115,7 @@ from .trainer_utils import (
     default_compute_objective,
     default_hp_space,
     denumpify_detensorize,
+    find_executable_batch_size,
     get_last_checkpoint,
     has_length,
     number_of_arguments,
@@ -548,6 +549,9 @@ class Trainer:
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

+        # Internal variables to keep track of the original batch size
+        self._train_batch_size = args.train_batch_size
+
         # very last
         self._memory_tracker.stop_and_update_metrics()

@@ -718,7 +722,7 @@ class Trainer:
             if self.args.world_size > 1:
                 train_dataset = IterableDatasetShard(
                     train_dataset,
-                    batch_size=self.args.train_batch_size,
+                    batch_size=self._train_batch_size,
                     drop_last=self.args.dataloader_drop_last,
                     num_processes=self.args.world_size,
                     process_index=self.args.process_index,
@@ -736,7 +740,7 @@ class Trainer:

         return DataLoader(
             train_dataset,
-            batch_size=self.args.train_batch_size,
+            batch_size=self._train_batch_size,
             sampler=train_sampler,
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
@@ -1267,6 +1271,20 @@ class Trainer:
             self._move_model_to_device(self.model, args.device)
         self.model_wrapped = self.model

+        inner_training_loop = find_executable_batch_size(
+            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+        )
+        return inner_training_loop(
+            args=args,
+            resume_from_checkpoint=resume_from_checkpoint,
+            trial=trial,
+            ignore_keys_for_eval=ignore_keys_for_eval,
+        )
+
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
+        self._train_batch_size = batch_size
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()

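The wiring above hands the retry logic to find_executable_batch_size; the actual halving behaviour lives in accelerate, but a simplified, self-contained sketch of the control flow applied to _inner_training_loop (ignoring CUDNN errors and cache clearing) looks like this:

def retry_with_smaller_batches(training_loop, starting_batch_size):
    # Simplified illustration only; the real implementation is accelerate's
    # find_executable_batch_size, which also handles CUDNN failures.
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return training_loop(batch_size=batch_size)
        except RuntimeError as err:
            if "out of memory" not in str(err).lower():
                raise  # unrelated failure, do not shrink the batch
            batch_size //= 2  # cut the batch size in half and try again
    raise RuntimeError("No executable batch size found.")


if __name__ == "__main__":
    def fake_loop(batch_size):
        if batch_size > 16:
            raise RuntimeError("CUDA out of memory.")
        return batch_size

    print(retry_with_smaller_batches(fake_loop, 128))  # prints 16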
@@ -36,6 +36,7 @@ from .utils import (
     is_torch_available,
     is_torch_cuda_available,
     is_torch_tpu_available,
+    requires_backends,
 )


@@ -355,6 +356,7 @@ class TrainerMemoryTracker:
     stages = {
         "__init__": "init",
         "train": "train",
+        "_inner_training_loop": "train",
         "evaluate": "eval",
         "predict": "test",
     }
@@ -584,6 +586,37 @@ class ShardedDDPOption(ExplicitEnum):
     AUTO_WRAP = "auto_wrap"


+def find_executable_batch_size(
+    function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
+):
+    """
+    Args:
+    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
+    CUDNN, the batch size is cut in half and passed to `function` `function` must take in a `batch_size` parameter as
+    its first argument.
+        function (`callable`, *optional*)
+            A function to wrap
+        starting_batch_size (`int`, *optional*)
+            The batch size to try and fit into memory
+        auto_find_batch_size (`bool`, *optional*)
+            If False, will just execute `function`
+    """
+    if function is None:
+        return functools.partial(
+            find_executable_batch_size,
+            starting_batch_size=starting_batch_size,
+            auto_find_batch_size=auto_find_batch_size,
+        )
+
+    if auto_find_batch_size:
+        requires_backends(find_executable_batch_size, "accelerate")
+        import accelerate.memory_utils as mem_utils
+
+        return mem_utils.find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
+
+    return functools.partial(function, batch_size=starting_batch_size)
+
+
 class FSDPOption(ExplicitEnum):
     FULL_SHARD = "full_shard"
     SHARD_GRAD_OP = "shard_grad_op"
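In decorator form this mirrors the tests added at the end of this diff; a small sketch (accelerate must be installed, since auto_find_batch_size=True takes the Accelerate code path):

from transformers.trainer_utils import find_executable_batch_size

attempted = []


@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
def mock_training_loop(batch_size):
    attempted.append(batch_size)
    if batch_size > 16:
        raise RuntimeError("CUDA out of memory.")


mock_training_loop()
print(attempted)  # expected: [64, 32, 16]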
@@ -443,6 +443,9 @@ class TrainingArguments:
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
+        auto_find_batch_size (`bool`, *optional*, defaults to `False`)
+            Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
+            CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
     """

     output_dir: str = field(
@@ -803,6 +806,13 @@ class TrainingArguments:
         metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
     )

+    auto_find_batch_size: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to automatically decrease the batch size in half and rerun the training loop again each time a CUDA Out-of-Memory was reached"
+        },
+    )
+
     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
         # This needs to happen before any call to self.device or self.n_gpu.
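The new field is also reachable as a command-line flag via HfArgumentParser, which is how the run_glue.py test later in this diff toggles it; a small sketch (the output path is a placeholder, torch is assumed installed):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "/tmp/out", "--auto_find_batch_size", "1"]
)
print(training_args.auto_find_batch_size)  # True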
@@ -85,6 +85,7 @@ from .import_utils import (
     DummyObject,
     OptionalDependencyNotAvailable,
     _LazyModule,
+    is_accelerate_available,
     is_apex_available,
     is_bitsandbytes_available,
     is_coloredlogs_available,
@@ -428,6 +428,10 @@ def is_protobuf_available():
     return importlib.util.find_spec("google.protobuf") is not None


+def is_accelerate_available():
+    return importlib.util.find_spec("accelerate") is not None
+
+
 def is_tokenizers_available():
     return importlib.util.find_spec("tokenizers") is not None

@@ -725,6 +729,12 @@ PYCTCDECODE_IMPORT_ERROR = """
 `pip install pyctcdecode`
 """

+# docstyle-ignore
+ACCELERATE_IMPORT_ERROR = """
+{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
+`pip install accelerate`
+"""
+

 BACKENDS_MAPPING = OrderedDict(
     [
@@ -750,6 +760,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
         ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
         ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+        ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
     ]
 )

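With is_accelerate_available and the accelerate entry in BACKENDS_MAPPING, optional-dependency guards can be written the usual way; an illustrative sketch (the class below is made up, but requires_backends and is_accelerate_available are the helpers used elsewhere in this diff):

from transformers.utils import is_accelerate_available, requires_backends


class NeedsAccelerate:
    def __init__(self):
        # Raises an ImportError built from ACCELERATE_IMPORT_ERROR when accelerate is missing.
        requires_backends(self, "accelerate")


if __name__ == "__main__":
    if is_accelerate_available():
        NeedsAccelerate()
        print("accelerate is installed")
    else:
        print("accelerate is missing; auto_find_batch_size would not be usable")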
@@ -21,6 +21,7 @@ import os
 import random
 import re
 import subprocess
+import sys
 import tempfile
 import time
 import unittest
@@ -58,6 +59,7 @@ from transformers.testing_utils import (
     require_torch_bf16,
     require_torch_gpu,
     require_torch_multi_gpu,
+    require_torch_non_multi_gpu,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
     require_wandb,
@@ -1075,6 +1077,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertAlmostEqual(a, a1, delta=1e-8)
         self.assertAlmostEqual(b, b1, delta=1e-8)

+    @slow
+    @require_torch_non_multi_gpu
+    def test_auto_batch_size_finder(self):
+
+        if torch.cuda.is_available():
+            torch.backends.cudnn.deterministic = True
+
+        SRC_DIR = os.path.abspath(
+            os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
+        )
+        sys.path.append(SRC_DIR)
+        import run_glue
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            testargs = f"""
+                run_glue.py
+                --model_name_or_path distilbert-base-uncased
+                --task_name mrpc
+                --do_train
+                --do_eval
+                --max_seq_len 128
+                --per_device_train_batch_size 4096
+                --learning_rate 2e-5
+                --num_train_epochs 1
+                --output_dir {tmpdir}
+                --auto_find_batch_size 0
+                """.split()
+            with self.assertRaises(RuntimeError):
+                with patch.object(sys, "argv", testargs):
+                    run_glue.main()
+
+        testargs[-1] = "1"
+        with patch.object(sys, "argv", testargs):
+            run_glue.main()
+
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
         train_dataset = RegressionDataset(length=128)
@@ -18,7 +18,8 @@ import unittest

 import numpy as np

-from transformers.testing_utils import require_torch
+from transformers.testing_utils import require_accelerate, require_torch
+from transformers.trainer_utils import find_executable_batch_size
 from transformers.utils import is_torch_available


@@ -420,3 +421,39 @@ class TrainerUtilsTest(unittest.TestCase):

         self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
         self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
+
+    @require_accelerate
+    def test_executable_batch_size(self):
+        batch_sizes = []
+
+        @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
+        def mock_training_loop_function(batch_size):
+            nonlocal batch_sizes
+            batch_sizes.append(batch_size)
+            if batch_size > 16:
+                raise RuntimeError("CUDA out of memory.")
+
+        mock_training_loop_function()
+        self.assertEqual(batch_sizes, [64, 32, 16])
+
+    @require_accelerate
+    def test_executable_batch_size_no_search(self):
+        batch_sizes = []
+
+        @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+        def mock_training_loop_function(batch_size):
+            nonlocal batch_sizes
+            batch_sizes.append(batch_size)
+
+        mock_training_loop_function()
+        self.assertEqual(batch_sizes, [64])
+
+    @require_accelerate
+    def test_executable_batch_size_with_error(self):
+        @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+        def mock_training_loop_function(batch_size):
+            raise RuntimeError("CUDA out of memory.")
+
+        with self.assertRaises(RuntimeError) as cm:
+            mock_training_loop_function()
+            self.assertEqual("CUDA out of memory", cm.args[0])