From 82efc53513a51660e629c7eca8210af1d67df00b Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Thu, 1 Aug 2024 15:18:43 -0400
Subject: [PATCH] Yell at the user if zero-3 init wasn't performed, but
 expected to have been done (#32299)

* Test this zach

* Test for improper init w/o zero3

* Move back

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Get rid of stars in warning

* Make private

* Make clear

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/modeling_utils.py |  7 +++++++
 src/transformers/trainer.py        | 10 ++++++++++
 tests/deepspeed/test_deepspeed.py  | 25 +++++++++++++++++++++++++
 3 files changed, 42 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 09c23e5b741..651d2072825 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1470,9 +1470,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # and memory copying it on CPU or each GPU first
             with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
                 model = cls(config, **kwargs)
+
         else:
             model = cls(config, **kwargs)
 
+        # Flag for if we init with `zero3`, add an attr to the model so we can check downstream for issues
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled()
+
         # restore default dtype if it was modified
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
@@ -3802,6 +3806,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)
 
+        # If we init with `zero3`, add an attr to the model so we can check downstream for issues
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled() and not is_quantized
+
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config
 
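The two modeling_utils.py hunks above only stamp a private marker; the code that consumes it lands in trainer.py below. A minimal sketch of the resulting behavior, assuming a build of transformers with this patch applied ("t5-small" is a placeholder checkpoint, not something this PR touches):

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("t5-small")

    # Both `from_config` and `from_pretrained` now stamp the marker, recording
    # whether DeepSpeed ZeRO-3 (`deepspeed.zero.Init`) was active when the
    # weights were created.
    model = AutoModel.from_config(config)
    print(model._transformers_zero3_init_used)  # False here: no ZeRO-3 config active

    # A bare constructor call such as `MyModel(config)` never sets the attribute;
    # the Trainer check below therefore reads it with a default of True, so
    # hand-built models are deliberately left alone.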
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 68ba7babfc5..59f0ed438bf 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -100,2 +100,3 @@
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .integrations.tpu import tpu_spmd_dataloader
@@ -435,6 +436,15 @@ class Trainer:
             )
         self.model_init = model_init
 
+        # Will reach this branch if the user has
+        # 1. Used `.from_pretrained` or `.from_config` to initialize their model
+        # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
+        # New model inits such as `MyModel()` will not hit this step
+        if is_deepspeed_zero3_enabled() and not getattr(model, "_transformers_zero3_init_used", True):
+            raise ValueError(
+                "Model was not initialized with `Zero-3` despite being configured for DeepSpeed Zero-3. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
+            )
+
         if model.__class__.__name__ in MODEL_MAPPING_NAMES:
             raise ValueError(
                 f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 7b50165babf..7b81ba40e47 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -709,6 +709,31 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         # Relative difference. See the note above how to get identical loss on a small bs
         self.assertTrue((no_grad_accum_loss - yes_grad_accum_loss) / (no_grad_accum_loss + 1e-15) <= 1e-3)
 
+    def test_missed_zero3_init(self):
+        from transformers import Trainer  # noqa
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            model = AutoModel.from_pretrained(T5_TINY)
+            training_args = TrainingArguments(
+                output_dir="./test_missed_zero3_init",
+                deepspeed=self.get_config_dict(ZERO3),
+            )
+            with self.assertRaises(
+                ValueError, msg="Model was not initialized with `Zero-3` despite being configured."
+            ):
+                _ = Trainer(
+                    model=model,
+                    args=training_args,
+                )
+            # Now do it properly, triggered from our `TrainingArguments` earlier
+            model = AutoModel.from_pretrained(T5_TINY)
+            trainer = Trainer(
+                model=model,
+                args=training_args,
+            )
+            assert trainer.is_deepspeed_enabled
+            assert model._transformers_zero3_init_used
+
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
         file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
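The test above also documents the fix for end users: build TrainingArguments (which activates the DeepSpeed config) before the model. An illustrative sketch outside the test harness, assuming a working DeepSpeed setup; the checkpoint name, output dir, and ds_zero3.json (a stage-3 DeepSpeed config file) are placeholders:

    from transformers import AutoModel, Trainer, TrainingArguments

    # Wrong order: the model is built before `TrainingArguments` registers the
    # ZeRO-3 config, so `deepspeed.zero.Init` never ran and Trainer now raises:
    #   model = AutoModel.from_pretrained("t5-small")
    #   args = TrainingArguments(output_dir="out", deepspeed="ds_zero3.json")
    #   Trainer(model=model, args=args)  # ValueError from the check above

    # Right order: create `TrainingArguments` first so that `from_pretrained`
    # sees ZeRO-3 enabled, partitions the weights at init time, and sets the flag.
    args = TrainingArguments(output_dir="out", deepspeed="ds_zero3.json")
    model = AutoModel.from_pretrained("t5-small")
    trainer = Trainer(model=model, args=args)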