diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a708d8deb4e..b1a95b43ada 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2251,7 +2251,7 @@ class Trainer:
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -2304,12 +2304,13 @@ class Trainer:
             # In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
             self.model = unwrap_model(self.model, recursive=True)
-            # configure fsdp plugin for qlora if any
-            self._fsdp_qlora_plugin_updates()

         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self.model = self.accelerator.prepare(self.model)
+                # configure fsdp plugin for qlora if any
+                self._fsdp_qlora_plugin_updates()
+                if self.accelerator.mixed_precision != "fp8":
+                    self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         # prepare using `accelerator` prepare
@@ -4187,7 +4188,7 @@ class Trainer:
             start_time = time.time()
             model = (
                 self.accelerator.prepare(model)
-                if self.is_deepspeed_enabled or self.is_fsdp_enabled
+                if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
                 else self.accelerator.prepare_model(model, evaluation_mode=True)
             )
             self.model_preparation_time = round(time.time() - start_time, 4)
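
Note (not part of the patch): below is a minimal, hypothetical sketch of a training script that would exercise the changed path, assuming a multi-GPU run launched with `accelerate launch --mixed_precision fp8` and a TransformerEngine or MS-AMP backend installed, so that `self.accelerator.mixed_precision == "fp8"` and `self.is_fsdp_enabled` is True. The checkpoint name, dataset, and hyperparameters are illustrative assumptions only, not taken from the patch.

# fsdp_fp8_repro.py -- hypothetical sketch, launch with:
#   accelerate launch --mixed_precision fp8 fsdp_fp8_repro.py
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_name = "sshleifer/tiny-gpt2"  # illustrative tiny checkpoint, not prescribed by the patch
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
train_dataset = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=raw.column_names,
)

args = TrainingArguments(
    output_dir="fsdp-fp8-out",
    per_device_train_batch_size=1,
    max_steps=10,
    # FSDP makes `self.is_fsdp_enabled` True, so with the patch optimizer creation is delayed
    # and, under fp8 mixed precision, the early `accelerator.prepare(self.model)` is skipped.
    fsdp="full_shard auto_wrap",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()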