mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Reduced the scope of the try/except block in the train method's emergency checkpointing logic, for easier debugging.
This commit is contained in:
parent 68663be9e1
commit cb9655a06e
@@ -2128,6 +2128,9 @@ class Trainer:
         return model

     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -2222,32 +2225,37 @@ class Trainer:
             self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
         )

-        try:
-            if args.push_to_hub:
-                try:
-                    # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
-                    hf_hub_utils.disable_progress_bars()
-                    return inner_training_loop(
-                        args=args,
-                        resume_from_checkpoint=resume_from_checkpoint,
-                        trial=trial,
-                        ignore_keys_for_eval=ignore_keys_for_eval,
-                    )
-                finally:
-                    hf_hub_utils.enable_progress_bars()
-            else:
-                return inner_training_loop(
-                    args=args,
-                    resume_from_checkpoint=resume_from_checkpoint,
-                    trial=trial,
-                    ignore_keys_for_eval=ignore_keys_for_eval,
-                )
-        except Exception as e:
-            # This is where you explicitly call your emergency save logic
-            logger.error(f"Caught an unexpected exception during training: {e}")
-            if self.args.enable_emergency_checkpoint:
-                self._common_emergency_save("training_exception")
-            raise e  # Re-raise the exception to signal that training failed
+        def _safe_inner_training_loop(*loop_args, **loop_kwargs):
+            try:
+                return inner_training_loop(*loop_args, **loop_kwargs)
+            except BaseException as e:  # BaseException also catches KeyboardInterrupt / OOM
+                logger.error(f"Caught unexpected exception during training: {e}")
+                if self.args.enable_emergency_checkpoint:
+                    self._common_emergency_save("training_exception")
+                raise  # re-raise for normal debugging
+
+        if args.push_to_hub:
+            hf_hub_utils.disable_progress_bars()
+            try:
+                return _safe_inner_training_loop(
+                    args=args,
+                    resume_from_checkpoint=resume_from_checkpoint,
+                    trial=trial,
+                    ignore_keys_for_eval=ignore_keys_for_eval,
+                )
+            finally:
+                hf_hub_utils.enable_progress_bars()
+        else:
+            return _safe_inner_training_loop(
+                args=args,
+                resume_from_checkpoint=resume_from_checkpoint,
+                trial=trial,
+                ignore_keys_for_eval=ignore_keys_for_eval,
+            )

     def get_tp_size(self) -> int:
         """Get the tensor parallel size from either the model or DeepSpeed config."""
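For reference, here is the narrowed pattern from the hunk above in isolation: a minimal, self-contained sketch in which run_training, save_emergency_checkpoint, and safe_run_training are illustrative stand-ins rather than the Trainer API. It shows why wrapping only the training call keeps failures in the surrounding setup/teardown code reporting their own tracebacks.

import logging

logger = logging.getLogger(__name__)


def run_training(num_steps):
    # Stand-in for the real inner training loop; fails partway to simulate a crash.
    for step in range(num_steps):
        if step == 3:
            raise RuntimeError("simulated failure at step 3")
    return "done"


def save_emergency_checkpoint(reason):
    # Stand-in for the emergency-save hook; a real version would write
    # model/optimizer state to disk here.
    logger.error("emergency checkpoint requested: %s", reason)


def safe_run_training(*loop_args, **loop_kwargs):
    # The try/except wraps only the training call itself, so bugs in the
    # surrounding setup/teardown code keep their original tracebacks.
    try:
        return run_training(*loop_args, **loop_kwargs)
    except BaseException as e:  # BaseException also covers KeyboardInterrupt
        logger.error("Caught unexpected exception during training: %s", e)
        save_emergency_checkpoint("training_exception")
        raise  # re-raise so the failure still propagates to the caller


if __name__ == "__main__":
    try:
        safe_run_training(10)
    except RuntimeError:
        print("training failed, but an emergency checkpoint was requested first")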
@@ -3282,7 +3290,7 @@ class Trainer:

         # Check if an emergency save has already completed or is in progress from another trigger
         # This prevents the atexit handler from running if the try...except in train() already initiated a save
-        if self._emergency_save_completed_this_run:  # <--- ADD THIS CHECK HERE
+        if self._emergency_save_completed_this_run:
             return

         self._emergency_save_running = True
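The flag checked above is what stops a second save when both the except path in train() and an atexit handler fire. A toy sketch of that double-trigger guard, assuming a flag-plus-atexit design as the diff suggests (EmergencySaver is a hypothetical class, not part of Trainer):

import atexit
import logging

logger = logging.getLogger(__name__)


class EmergencySaver:
    def __init__(self):
        self._emergency_save_completed_this_run = False
        # The same save routine is registered to fire at interpreter exit.
        atexit.register(self.emergency_save, "atexit")

    def emergency_save(self, trigger):
        # If the except path in train() already saved, skip the atexit run
        # so the checkpoint is only written once per process.
        if self._emergency_save_completed_this_run:
            return
        logger.error("emergency save triggered by: %s", trigger)
        # ... checkpoint files would be written here ...
        self._emergency_save_completed_this_run = True


saver = EmergencySaver()
saver.emergency_save("training_exception")  # first trigger performs the save
# When the process exits, the atexit-registered call returns immediately.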
@@ -3306,7 +3314,7 @@ class Trainer:
         minimal_state = {
             "global_step": self.state.global_step,
             "epoch": self.state.epoch,
-            "is_world_process_zero": self.state.is_world_process_zero,  # Useful for sanity check on resume
+            "is_world_process_zero": self.state.is_world_process_zero,
             "is_local_process_zero": self.state.is_local_process_zero,
         }
         with open(minimal_state_path, "w") as f:
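The hunk cuts off at the open() call, so the body of the with block isn't shown; presumably the dict is serialized as JSON so a resume can sanity-check global_step and epoch. A hedged sketch of that round trip, in which json.dump, the file name, and the placeholder values are all assumptions:

import json

# Field names copied from the diff; the values here are placeholders.
minimal_state = {
    "global_step": 1200,
    "epoch": 2.5,
    "is_world_process_zero": True,
    "is_local_process_zero": True,
}

minimal_state_path = "minimal_state.json"  # hypothetical file name

# Assumption: the elided body of the `with open(...)` block serializes the dict as JSON.
with open(minimal_state_path, "w") as f:
    json.dump(minimal_state, f, indent=2)

# On resume, the stored step/epoch can be sanity-checked before training continues.
with open(minimal_state_path) as f:
    restored = json.load(f)
assert restored["global_step"] == minimal_state["global_step"]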