Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)
Yell at the user if zero-3 init wasn't performed, but expected to have been done (#32299)
* Test this zach
* Test for improper init w/o zero3
* Move back
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Get rid of stars in warning
* Make private
* Make clear

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent 51ab25e293
commit 82efc53513
@@ -1470,9 +1470,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # and memory copying it on CPU or each GPU first
             with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
                 model = cls(config, **kwargs)
         else:
             model = cls(config, **kwargs)
 
+        # Flag for if we init with `zero3`, add an attr to the model so we can check downstream for issues
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled()
+
         # restore default dtype if it was modified
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
@@ -3802,6 +3806,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 # Let's make sure we don't run the init function of buffer modules
                 model = cls(config, *model_args, **model_kwargs)
 
+        # If we init with `zero3`, add an attr to the model so we can check downstream for issues
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled() and not is_quantized
+
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config
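For orientation, here is a minimal sketch (not part of the diff) of what the two hunks above stamp onto a model. `BertConfig` is used only as a stand-in so the snippet runs offline, and the printed value assumes a transformers build that contains this change:

from transformers import AutoModel, BertConfig

# Build a model through `from_config` without any DeepSpeed ZeRO-3 configuration active.
cfg = BertConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64)
model = AutoModel.from_config(cfg)

# The hunks above record whether zero-3 init was active at construction time;
# no zero-3 config exists here, so the attribute is stamped False.
print(getattr(model, "_transformers_zero3_init_used", True))  # -> False with this change applied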
@@ -100,6 +100,7 @@ from .trainer_pt_utils import (
     get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
+    is_deepspeed_zero3_enabled,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -435,6 +436,15 @@ class Trainer:
             )
         self.model_init = model_init
 
+        # Will reach this branch if the user has
+        # 1. Used `.from_pretrained` or `.from_config` to initialize their model
+        # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
+        # New models init such as `MyModel()` will not hit this step
+        if is_deepspeed_zero3_enabled() and not getattr(model, "_transformers_zero3_init_used", True):
+            raise ValueError(
+                "Model was not initialized with `Zero-3` despite being configured for DeepSpeed Zero-3. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
+            )
+
         if model.__class__.__name__ in MODEL_MAPPING_NAMES:
             raise ValueError(
                 f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
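A condensed sketch (not part of the diff) of the ordering rule this check enforces. It assumes deepspeed is installed, a ZeRO-3 JSON config exists at the placeholder path "ds_config_zero3.json", the checkpoint name is a placeholder, and a single-GPU distributed environment is available (the test below fakes one with mockenv_context):

from transformers import AutoModel, Trainer, TrainingArguments

# Wrong order: the model is loaded before any ZeRO-3 config exists, so it is built
# without `deepspeed.zero.Init` and carries `_transformers_zero3_init_used = False`.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5")  # placeholder checkpoint
args = TrainingArguments(output_dir="./out", deepspeed="ds_config_zero3.json")
# Trainer(model=model, args=args)  # raises: "Model was not initialized with `Zero-3` ..."

# Right order: with `TrainingArguments` (and thus the ZeRO-3 config) created first,
# re-loading the model runs under zero-3 init and the flag is stamped True.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5")
trainer = Trainer(model=model, args=args)  # no error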
@@ -709,6 +709,31 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         # Relative difference. See the note above how to get identical loss on a small bs
         self.assertTrue((no_grad_accum_loss - yes_grad_accum_loss) / (no_grad_accum_loss + 1e-15) <= 1e-3)
 
+    def test_missed_zero3_init(self):
+        from transformers import Trainer  # noqa
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            model = AutoModel.from_pretrained(T5_TINY)
+            training_args = TrainingArguments(
+                output_dir="./test_missed_zero3_init",
+                deepspeed=self.get_config_dict(ZERO3),
+            )
+            with self.assertRaises(
+                ValueError, msg="Model was not initialized with `Zero-3` despite being configured."
+            ):
+                _ = Trainer(
+                    model=model,
+                    args=training_args,
+                )
+            # Now do it properly, triggered from our `TrainingArguments` earlier
+            model = AutoModel.from_pretrained(T5_TINY)
+            trainer = Trainer(
+                model=model,
+                args=training_args,
+            )
+            assert trainer.is_deepspeed_enabled
+            assert model._transformers_zero3_init_used
+
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
         file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
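One more sketch (not part of the diff) of the escape hatch noted in the Trainer comment above: a model built by hand, here a hypothetical MyModel, never passes through from_pretrained/from_config, so the attribute is absent and the permissive getattr default keeps the new guard silent:

import torch.nn as nn

class MyModel(nn.Module):  # hypothetical user-defined model, not part of this commit
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = MyModel()

# No `from_pretrained`/`from_config` ran, so the attribute was never stamped, and the
# `getattr(..., True)` default means the zero-3 check in `Trainer.__init__` does not fire.
assert getattr(model, "_transformers_zero3_init_used", True) is True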