mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 22:38:58 +06:00
parent
6009642459
commit
7237b3ecfc
@ -2251,7 +2251,7 @@ class Trainer:
|
||||
else:
|
||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
||||
|
||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
||||
if self._created_lr_scheduler:
|
||||
@ -2304,12 +2304,13 @@ class Trainer:
|
||||
# In case of auto_find_batch_size=True
|
||||
# Remove FSDP wrapping from sub-models.
|
||||
self.model = unwrap_model(self.model, recursive=True)
|
||||
# configure fsdp plugin for qlora if any
|
||||
self._fsdp_qlora_plugin_updates()
|
||||
|
||||
if delay_optimizer_creation:
|
||||
if use_accelerator_prepare:
|
||||
self.model = self.accelerator.prepare(self.model)
|
||||
# configure fsdp plugin for qlora if any
|
||||
self._fsdp_qlora_plugin_updates()
|
||||
if self.accelerator.mixed_precision != "fp8":
|
||||
self.model = self.accelerator.prepare(self.model)
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
|
||||
# prepare using `accelerator` prepare
|
||||
@ -4187,7 +4188,7 @@ class Trainer:
|
||||
start_time = time.time()
|
||||
model = (
|
||||
self.accelerator.prepare(model)
|
||||
if self.is_deepspeed_enabled or self.is_fsdp_enabled
|
||||
if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
|
||||
else self.accelerator.prepare_model(model, evaluation_mode=True)
|
||||
)
|
||||
self.model_preparation_time = round(time.time() - start_time, 4)
|
||||
|
Loading…
Reference in New Issue
Block a user