fix bugs with trainer (#24134)
* fix the deepspeed test failures
* apex fix
* FSDP save ckpt fix
* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
commit f2b918356c (parent be10092e63)
@@ -1749,7 +1749,16 @@ class Trainer:
         # prepare using `accelerator` prepare
         if use_accelerator_prepare:
-            model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            if hasattr(self.lr_scheduler, "step"):
+                if self.use_apex:
+                    model = self.accelerator.prepare(self.model)
+                else:
+                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            else:
+                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+                    self.model, self.optimizer, self.lr_scheduler
+                )

         if self.is_fsdp_enabled:
             self.model = model
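This first hunk reworks how `accelerator.prepare` is called: with apex, only the model is prepared (apex's `amp.initialize` already owns the optimizer), and when the learning-rate scheduler comes from the DeepSpeed config it is an Accelerate `DummyScheduler` without a real `step`, so it has to be prepared together with the model and optimizer. Below is a minimal sketch of that last case, assuming Accelerate is launched with a DeepSpeed config that defines both an optimizer and a scheduler; the toy `torch.nn.Linear` model is only for illustration and is not part of the commit:

    import torch
    from accelerate import Accelerator
    from accelerate.utils import DummyOptim, DummyScheduler

    accelerator = Accelerator()                 # assumes a DeepSpeed plugin is configured
    model = torch.nn.Linear(8, 2)               # toy model, illustration only
    optimizer = DummyOptim(model.parameters())  # real optimizer is built by DeepSpeed
    lr_scheduler = DummyScheduler(optimizer)    # real scheduler is built by DeepSpeed

    # The placeholder scheduler has no usable `step`, which is the case the
    # `hasattr(self.lr_scheduler, "step")` check above separates out: it must
    # go through `prepare` together with the model and optimizer so DeepSpeed
    # can materialize all three at once.
    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)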
@@ -2841,6 +2850,7 @@ class Trainer:
             or self.is_fsdp_enabled
         ):
             if self.is_fsdp_enabled:
+                os.makedirs(output_dir, exist_ok=True)
                 self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
             else:
                 state_dict = self.model.state_dict()
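This second hunk is the "FSDP save ckpt fix" from the commit message: the FSDP branch writes the checkpoint through the Accelerate FSDP plugin, and the fix creates `output_dir` up front so that write cannot fail when the directory does not exist yet. A minimal sketch of the same defensive pattern, not the Trainer's saving code (the `save_checkpoint` helper and file name here are made up for illustration):

    import os
    import torch

    def save_checkpoint(state_dict, output_dir, filename="pytorch_model.bin"):
        # Create the target directory before writing, mirroring the added
        # os.makedirs call above; exist_ok=True keeps it safe if several
        # processes reach this point at roughly the same time.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(state_dict, os.path.join(output_dir, filename))

    # e.g. save_checkpoint(model.state_dict(), "checkpoints/step-500")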