Commit 7237b3ecfc (parent 6009642459) in https://github.com/huggingface/transformers.git
@@ -2251,7 +2251,7 @@ class Trainer:
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
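This first hunk makes plain FSDP runs take the delayed-optimizer path too. Optimizer creation has to wait because FSDP flattens and shards parameters when it wraps the model, so an optimizer built earlier would hold references to the old, unwrapped tensors. A minimal sketch of the required ordering (the Linear model is an illustrative stand-in, not from the commit; with no FSDP plugin configured, accelerate's prepare is just a pass-through wrap):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)  # illustrative stand-in

# FSDP flattens and shards parameters while wrapping the model, so the
# model must be prepared before the optimizer is constructed...
model = accelerator.prepare(model)

# ...and the optimizer is then built over the wrapped model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer = accelerator.prepare(optimizer)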
@@ -2304,12 +2304,13 @@ class Trainer:
             # In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
             self.model = unwrap_model(self.model, recursive=True)
-        # configure fsdp plugin for qlora if any
-        self._fsdp_qlora_plugin_updates()

         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self.model = self.accelerator.prepare(self.model)
+                # configure fsdp plugin for qlora if any
+                self._fsdp_qlora_plugin_updates()
+                if self.accelerator.mixed_precision != "fp8":
+                    self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         # prepare using `accelerator` prepare
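The second hunk moves the qlora plugin update inside the `use_accelerator_prepare` branch and, when mixed precision is fp8, skips the standalone model prepare, leaving the model to be prepared later together with the optimizer. A sketch of that joint-prepare pattern as I read the change (assumes an fp8-capable backend such as TransformerEngine is installed; the model is an illustrative stand-in):

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp8")  # needs an fp8 backend installed
model = torch.nn.Linear(8, 2)  # illustrative stand-in
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

if accelerator.mixed_precision != "fp8":
    # Non-fp8: the model can be prepared on its own.
    model = accelerator.prepare(model)
    optimizer = accelerator.prepare(optimizer)
else:
    # fp8: prepare model and optimizer in one call so the fp8 wrapping
    # and the optimizer state are set up together.
    model, optimizer = accelerator.prepare(model, optimizer)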
@@ -4187,7 +4188,7 @@ class Trainer:
             start_time = time.time()
             model = (
                 self.accelerator.prepare(model)
-                if self.is_deepspeed_enabled or self.is_fsdp_enabled
+                if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
                 else self.accelerator.prepare_model(model, evaluation_mode=True)
             )
             self.model_preparation_time = round(time.time() - start_time, 4)
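The evaluation path gets the matching guard: an fp8 FSDP run now falls through to `prepare_model(..., evaluation_mode=True)` rather than the full `prepare`. A self-contained sketch of the branch (the flags and model are illustrative stand-ins mirroring Trainer state; the two accelerate calls are the real API):

import time
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)  # illustrative stand-in
is_deepspeed_enabled = False   # assumed flag mirroring Trainer state
is_fsdp_enabled = True         # assumed flag mirroring Trainer state

start_time = time.time()
model = (
    accelerator.prepare(model)
    if is_deepspeed_enabled or (is_fsdp_enabled and accelerator.mixed_precision != "fp8")
    # evaluation_mode=True prepares the model for inference only,
    # without any optimizer or scheduler handling.
    else accelerator.prepare_model(model, evaluation_mode=True)
)
model_preparation_time = round(time.time() - start_time, 4)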