Fix FSDP no longer working (#35212)

Fix FSDP failing
Zach Mueller, 2024-12-13 13:20:51 -05:00 (committed by GitHub)
parent 6009642459
commit 7237b3ecfc


@@ -2251,7 +2251,7 @@ class Trainer:
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
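The added `self.is_fsdp_enabled` term is the heart of the fix: once `accelerator.prepare` wraps the model with FSDP, the module's parameters are re-registered as flattened shards, so an optimizer created before wrapping would step stale tensors. A minimal sketch of that failure mode, assuming PyTorch's FSDP1 with its default `use_orig_params=False`, using a single-process gloo group purely for illustration:

```python
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process CPU process group so FSDP can wrap at all (illustration only).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 8)
stale = {id(p) for p in model.parameters()}  # what a too-early optimizer would hold

wrapped = FSDP(model)  # parameters are re-registered as flat shards
fresh = {id(p) for p in wrapped.parameters()}
assert stale.isdisjoint(fresh)  # the original tensors are no longer the ones trained

# Correct ordering: exactly what delaying optimizer creation guarantees.
optimizer = torch.optim.AdamW(wrapped.parameters(), lr=1e-4)

dist.destroy_process_group()
```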
@@ -2304,12 +2304,13 @@ class Trainer:
             # In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
             self.model = unwrap_model(self.model, recursive=True)
-            # configure fsdp plugin for qlora if any
-            self._fsdp_qlora_plugin_updates()
 
         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self.model = self.accelerator.prepare(self.model)
+                # configure fsdp plugin for qlora if any
+                self._fsdp_qlora_plugin_updates()
+                if self.accelerator.mixed_precision != "fp8":
+                    self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
         # prepare using `accelerator` prepare
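Two things move here: `_fsdp_qlora_plugin_updates()` now runs on the delayed path right before the model is prepared, and preparation itself is skipped under fp8, presumably so the model can go through the later combined `prepare` call (the `# prepare using \`accelerator\` prepare` block in the surrounding code) together with the optimizer rather than being wrapped twice. A usage-level sketch of the guard, assuming only accelerate's public `Accelerator` API and a stand-in model:

```python
from accelerate import Accelerator
from torch import nn

accelerator = Accelerator()  # mixed_precision is read from the accelerate config/env
model = nn.Linear(8, 8)

if accelerator.mixed_precision != "fp8":
    # Usual delayed-creation path: apply wrapping (FSDP, DDP, ...) now.
    model = accelerator.prepare(model)
# Under fp8, skip: the model is prepared later in one call with the optimizer.
```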
@@ -4187,7 +4188,7 @@ class Trainer:
         start_time = time.time()
         model = (
             self.accelerator.prepare(model)
-            if self.is_deepspeed_enabled or self.is_fsdp_enabled
+            if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
             else self.accelerator.prepare_model(model, evaluation_mode=True)
         )
         self.model_preparation_time = round(time.time() - start_time, 4)
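The evaluation path gets the matching guard: with fp8 and FSDP together, the model now falls through to `prepare_model(..., evaluation_mode=True)`, accelerate's inference-only preparation (device placement and wrapping without creating or touching optimizer state). A minimal sketch of that fallback, with a stand-in model:

```python
from accelerate import Accelerator
from torch import nn

accelerator = Accelerator()
model = nn.Linear(8, 8)

# Inference-only preparation: no training state is created or modified.
model = accelerator.prepare_model(model, evaluation_mode=True)
```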