mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Only disallow DeepSpeed Zero-3 for auto bs finder (#31731)
* Only disallow DeepSpeed * Clean * DeepSpeed! * Add a test for deepspeed
This commit is contained in:
parent
03c12d0d63
commit
6b7d64ac1c
@ -4822,10 +4822,15 @@ class Trainer:
|
|||||||
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
|
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
|
||||||
raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
|
raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
|
||||||
|
|
||||||
# `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP
|
# `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
|
||||||
if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size:
|
if (
|
||||||
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
|
self.is_deepspeed_enabled
|
||||||
raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.")
|
and self.accelerator.state.deepspeed_plugin.zero_stage == 3
|
||||||
|
and self.args.auto_find_batch_size
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
|
||||||
|
)
|
||||||
|
|
||||||
def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
|
def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
|
||||||
"""
|
"""
|
||||||
|
@ -145,6 +145,21 @@ if is_accelerate_available():
|
|||||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||||
|
|
||||||
|
|
||||||
|
class MockCudaOOMCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
Simple callback to simulate CUDA OOM error if
|
||||||
|
the batch size is >= to `batch_size_limit`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, batch_size_limit=16):
|
||||||
|
self.batch_size_limit = batch_size_limit
|
||||||
|
|
||||||
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
|
# simulate OOM on the first step
|
||||||
|
if state.train_batch_size >= self.batch_size_limit:
|
||||||
|
raise RuntimeError("CUDA out of memory.")
|
||||||
|
|
||||||
|
|
||||||
class RegressionDataset:
|
class RegressionDataset:
|
||||||
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
|
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
@ -2504,7 +2519,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
run_glue.main()
|
run_glue.main()
|
||||||
|
|
||||||
@require_deepspeed
|
@require_deepspeed
|
||||||
def test_auto_batch_size_with_resume_from_checkpoint_with_deepspeed(self):
|
def test_auto_batch_size_with_deepspeed(self):
|
||||||
train_dataset = RegressionDataset(length=128)
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
|
||||||
config = RegressionModelConfig(a=0, b=2)
|
config = RegressionModelConfig(a=0, b=2)
|
||||||
@ -2512,15 +2527,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
class MockCudaOOMCallback(TrainerCallback):
|
for stage in [1, 2]:
|
||||||
def on_step_end(self, args, state, control, **kwargs):
|
|
||||||
# simulate OOM on the first step
|
|
||||||
if state.train_batch_size >= 16:
|
|
||||||
raise RuntimeError("CUDA out of memory.")
|
|
||||||
|
|
||||||
deepspeed = {
|
deepspeed = {
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 1,
|
"stage": stage,
|
||||||
},
|
},
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
@ -2530,15 +2540,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
tmp_dir,
|
tmp_dir,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
max_steps=2,
|
max_steps=2,
|
||||||
save_steps=1,
|
save_strategy="no",
|
||||||
per_device_train_batch_size=16,
|
per_device_train_batch_size=16,
|
||||||
auto_find_batch_size=True,
|
auto_find_batch_size=True,
|
||||||
deepspeed=deepspeed,
|
deepspeed=deepspeed,
|
||||||
)
|
)
|
||||||
# Note: This can have issues, for now we don't support this functionality
|
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
||||||
# ref: https://github.com/huggingface/transformers/pull/29057
|
trainer.train()
|
||||||
with self.assertRaises(NotImplementedError):
|
self.assertEqual(trainer._train_batch_size, 8)
|
||||||
_ = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
|
||||||
|
|
||||||
def test_auto_batch_size_with_resume_from_checkpoint(self):
|
def test_auto_batch_size_with_resume_from_checkpoint(self):
|
||||||
train_dataset = RegressionDataset(length=128)
|
train_dataset = RegressionDataset(length=128)
|
||||||
@ -2548,12 +2557,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
class MockCudaOOMCallback(TrainerCallback):
|
|
||||||
def on_step_end(self, args, state, control, **kwargs):
|
|
||||||
# simulate OOM on the first step
|
|
||||||
if state.train_batch_size >= 16:
|
|
||||||
raise RuntimeError("CUDA out of memory.")
|
|
||||||
|
|
||||||
args = RegressionTrainingArguments(
|
args = RegressionTrainingArguments(
|
||||||
tmp_dir,
|
tmp_dir,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
|
Loading…
Reference in New Issue
Block a user