mirror of https://github.com/huggingface/transformers.git
Save checkpoint to temporary directory to handle partial saves during failures (#35580)
Save the checkpoint to a temporary folder first, since partial or missing files left behind by a failure cause errors during load.
parent 3dd1de39bb
commit e3458af726
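For context, the pattern this patch adopts, reduced to a minimal sketch (the helper name and the write_fn callback are illustrative, not part of the Trainer API): write all checkpoint artifacts into a temporary directory created inside the run directory, then rename the finished directory into place. A crash mid-save then leaves only a tmp-* directory behind, never a half-written checkpoint-* folder that would fail to load.

    import os
    import tempfile

    def save_checkpoint_atomically(run_dir, checkpoint_name, write_fn):
        # Hypothetical helper showing the write-to-temp-then-rename pattern.
        final_dir = os.path.join(run_dir, checkpoint_name)
        # dir=run_dir keeps the temp dir on the same filesystem as the final
        # location, so os.renames is a cheap metadata operation, not a copy.
        with tempfile.TemporaryDirectory(prefix="tmp-ckpt-", dir=run_dir) as tmp_dir:
            write_fn(tmp_dir)  # write model/optimizer/state files into tmp_dir
            # Publish the finished checkpoint. As in the patch below, renaming
            # inside the `with` relies on TemporaryDirectory's cleanup
            # tolerating the already-moved directory.
            os.renames(tmp_dir, final_dir)
        return final_dir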
src/transformers/trainer.py

@@ -18,6 +18,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
 
 import contextlib
 import copy
+import errno
 import functools
 import glob
 import importlib.metadata
@@ -3128,31 +3129,41 @@ class Trainer:
         self.store_flos()
 
         run_dir = self._get_output_dir(trial=trial)
-        output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
+        checkpoint_dir = os.path.join(run_dir, checkpoint_folder)
+        with tempfile.TemporaryDirectory(prefix=f"tmp-{PREFIX_CHECKPOINT_DIR}-", dir=run_dir) as output_dir:
+            self.save_model(output_dir, _internal_call=True)
 
-        if not self.args.save_only_model:
-            # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
-            # Save RNG state
-            self._save_rng_state(output_dir)
+            if not self.args.save_only_model:
+                # Save optimizer and scheduler
+                self._save_optimizer_and_scheduler(output_dir)
+                # Save RNG state
+                self._save_rng_state(output_dir)
 
-        # Save the Trainer state
-        if self.args.should_save:
-            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
-            for cb in [
-                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
-            ]:
-                cb_name = cb.__class__.__name__
-                cb_state = cb.state()
-                if isinstance(self.state.stateful_callbacks[cb_name], list):
-                    self.state.stateful_callbacks[cb_name].append(cb_state)
+            # Save the Trainer state
+            if self.args.should_save:
+                # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
+                for cb in [
+                    cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+                ]:
+                    cb_name = cb.__class__.__name__
+                    cb_state = cb.state()
+                    if isinstance(self.state.stateful_callbacks[cb_name], list):
+                        self.state.stateful_callbacks[cb_name].append(cb_state)
+                    else:
+                        self.state.stateful_callbacks[cb_name] = cb_state
+                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+            try:
+                os.renames(output_dir, checkpoint_dir)
+            except OSError as e:
+                if e.errno in [errno.ENOTEMPTY, errno.EEXIST]:  # Directory/File already exists
+                    shutil.rmtree(checkpoint_dir)
+                    os.renames(output_dir, checkpoint_dir)
                 else:
-                    self.state.stateful_callbacks[cb_name] = cb_state
-            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+                    raise
 
         if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
+            self._push_from_checkpoint(checkpoint_dir)
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
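The errno handling in the except branch covers the case where a checkpoint-* directory of the same name already exists, for example one left over from an earlier interrupted run: renaming onto a non-empty directory fails with ENOTEMPTY on Linux (EEXIST on some other platforms), so the stale directory is deleted and the rename retried. A small self-contained sketch of that behavior, with made-up file names for illustration:

    import errno
    import os
    import shutil
    import tempfile

    with tempfile.TemporaryDirectory() as run_dir:
        # A stale checkpoint left behind by an earlier, interrupted save.
        checkpoint_dir = os.path.join(run_dir, "checkpoint-100")
        os.makedirs(checkpoint_dir)
        open(os.path.join(checkpoint_dir, "stale.bin"), "w").close()

        # A freshly written temporary checkpoint, ready to be published.
        tmp_dir = os.path.join(run_dir, "tmp-checkpoint-100")
        os.makedirs(tmp_dir)
        open(os.path.join(tmp_dir, "model.safetensors"), "w").close()

        try:
            os.renames(tmp_dir, checkpoint_dir)
        except OSError as e:
            if e.errno in [errno.ENOTEMPTY, errno.EEXIST]:
                shutil.rmtree(checkpoint_dir)  # drop the stale checkpoint
                os.renames(tmp_dir, checkpoint_dir)
            else:
                raise

        print(sorted(os.listdir(checkpoint_dir)))  # ['model.safetensors']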