diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1002ceb0cfe..423eaa69fd1 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -4822,10 +4822,15 @@ class Trainer: wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.") - # `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP - if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size: - wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" - raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.") + # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3 + if ( + self.is_deepspeed_enabled + and self.accelerator.state.deepspeed_plugin.zero_stage == 3 + and self.args.auto_find_batch_size + ): + raise ValueError( + "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP" + ) def propagate_args_to_deepspeed(self, auto_find_batch_size=False): """ diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a557e0f80b7..49213f19187 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -145,6 +145,21 @@ if is_accelerate_available(): PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" +class MockCudaOOMCallback(TrainerCallback): + """ + Simple callback to simulate CUDA OOM error if + the batch size is >= to `batch_size_limit`. + """ + + def __init__(self, batch_size_limit=16): + self.batch_size_limit = batch_size_limit + + def on_step_end(self, args, state, control, **kwargs): + # simulate OOM on the first step + if state.train_batch_size >= self.batch_size_limit: + raise RuntimeError("CUDA out of memory.") + + class RegressionDataset: def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): np.random.seed(seed) @@ -2504,7 +2519,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): run_glue.main() @require_deepspeed - def test_auto_batch_size_with_resume_from_checkpoint_with_deepspeed(self): + def test_auto_batch_size_with_deepspeed(self): train_dataset = RegressionDataset(length=128) config = RegressionModelConfig(a=0, b=2) @@ -2512,33 +2527,27 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): tmp_dir = self.get_auto_remove_tmp_dir() - class MockCudaOOMCallback(TrainerCallback): - def on_step_end(self, args, state, control, **kwargs): - # simulate OOM on the first step - if state.train_batch_size >= 16: - raise RuntimeError("CUDA out of memory.") - - deepspeed = { - "zero_optimization": { - "stage": 1, - }, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - } + for stage in [1, 2]: + deepspeed = { + "zero_optimization": { + "stage": stage, + }, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + } args = RegressionTrainingArguments( tmp_dir, do_train=True, max_steps=2, - save_steps=1, + save_strategy="no", per_device_train_batch_size=16, auto_find_batch_size=True, deepspeed=deepspeed, ) - # Note: This can have issues, for now we don't support this functionality - # ref: https://github.com/huggingface/transformers/pull/29057 - with self.assertRaises(NotImplementedError): - _ = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()]) + trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()]) + trainer.train() + self.assertEqual(trainer._train_batch_size, 8) def test_auto_batch_size_with_resume_from_checkpoint(self): train_dataset = RegressionDataset(length=128) @@ -2548,12 +2557,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): tmp_dir = self.get_auto_remove_tmp_dir() - class MockCudaOOMCallback(TrainerCallback): - def on_step_end(self, args, state, control, **kwargs): - # simulate OOM on the first step - if state.train_batch_size >= 16: - raise RuntimeError("CUDA out of memory.") - args = RegressionTrainingArguments( tmp_dir, do_train=True,