Save checkpoint to temporary directory to handle partial saves during failures (#35580)

Save checkpoint to temporary folder first

Since partial/missing files due to failures throw error during load
This commit is contained in:
SilverSoldier 2025-02-06 19:18:05 +05:30 committed by GitHub
parent 3dd1de39bb
commit e3458af726
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,6 +18,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
import contextlib
import copy
import errno
import functools
import glob
import importlib.metadata
@ -3128,31 +3129,41 @@ class Trainer:
self.store_flos()
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir, _internal_call=True)
checkpoint_dir = os.path.join(run_dir, checkpoint_folder)
with tempfile.TemporaryDirectory(prefix=f"tmp-{PREFIX_CHECKPOINT_DIR}-", dir=run_dir) as output_dir:
self.save_model(output_dir, _internal_call=True)
if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)
# Save RNG state
self._save_rng_state(output_dir)
if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)
# Save RNG state
self._save_rng_state(output_dir)
# Save the Trainer state
if self.args.should_save:
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
for cb in [
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]:
cb_name = cb.__class__.__name__
cb_state = cb.state()
if isinstance(self.state.stateful_callbacks[cb_name], list):
self.state.stateful_callbacks[cb_name].append(cb_state)
# Save the Trainer state
if self.args.should_save:
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
for cb in [
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]:
cb_name = cb.__class__.__name__
cb_state = cb.state()
if isinstance(self.state.stateful_callbacks[cb_name], list):
self.state.stateful_callbacks[cb_name].append(cb_state)
else:
self.state.stateful_callbacks[cb_name] = cb_state
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
try:
os.renames(output_dir, checkpoint_dir)
except OSError as e:
if e.errno in [errno.ENOTEMPTY, errno.EEXIST]: # Directory/File already exists
shutil.rmtree(checkpoint_dir)
os.renames(output_dir, checkpoint_dir)
else:
self.state.stateful_callbacks[cb_name] = cb_state
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
raise
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
self._push_from_checkpoint(checkpoint_dir)
# Maybe delete some older checkpoints.
if self.args.should_save: