Reduced the scope of the try/except block around the emergency checkpointing logic in the train method for easier debugging.

AyushSharma173 2025-06-30 19:52:23 -05:00
parent 68663be9e1
commit cb9655a06e


@ -2128,6 +2128,9 @@ class Trainer:
         return model

     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@ -2222,32 +2225,37 @@ class Trainer:
             self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
         )
-        try:
-            if args.push_to_hub:
-                try:
-                    # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
-                    hf_hub_utils.disable_progress_bars()
-                    return inner_training_loop(
-                        args=args,
-                        resume_from_checkpoint=resume_from_checkpoint,
-                        trial=trial,
-                        ignore_keys_for_eval=ignore_keys_for_eval,
-                    )
-                finally:
-                    hf_hub_utils.enable_progress_bars()
-            else:
-                return inner_training_loop(
-                    args=args,
-                    resume_from_checkpoint=resume_from_checkpoint,
-                    trial=trial,
-                    ignore_keys_for_eval=ignore_keys_for_eval,
-                )
-        except Exception as e:
-            # This is where you explicitly call your emergency save logic
-            logger.error(f"Caught an unexpected exception during training: {e}")
-            if self.args.enable_emergency_checkpoint:
-                self._common_emergency_save("training_exception")
-            raise e  # Re-raise the exception to signal that training failed
+        def _safe_inner_training_loop(*loop_args, **loop_kwargs):
+            try:
+                return inner_training_loop(*loop_args, **loop_kwargs)
+            except BaseException as e:  # BaseException also catches KeyboardInterrupt / OOM
+                logger.error(f"Caught unexpected exception during training: {e}")
+                if self.args.enable_emergency_checkpoint:
+                    self._common_emergency_save("training_exception")
+                raise  # re-raise for normal debugging
+
+        if args.push_to_hub:
+            hf_hub_utils.disable_progress_bars()
+            try:
+                return _safe_inner_training_loop(
+                    args=args,
+                    resume_from_checkpoint=resume_from_checkpoint,
+                    trial=trial,
+                    ignore_keys_for_eval=ignore_keys_for_eval,
+                )
+            finally:
+                hf_hub_utils.enable_progress_bars()
+        else:
+            return _safe_inner_training_loop(
+                args=args,
+                resume_from_checkpoint=resume_from_checkpoint,
+                trial=trial,
+                ignore_keys_for_eval=ignore_keys_for_eval,
+            )
     def get_tp_size(self) -> int:
         """Get the tensor parallel size from either the model or DeepSpeed config."""
@ -3282,7 +3290,7 @@ class Trainer:
         # Check if an emergency save has already completed or is in progress from another trigger
         # This prevents the atexit handler from running if the try...except in train() already initiated a save
-        if self._emergency_save_completed_this_run:  # <--- ADD THIS CHECK HERE
+        if self._emergency_save_completed_this_run:
             return

         self._emergency_save_running = True
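The guard above exists because two triggers can fire for one failure: the except path in train() and a process-exit hook. A standalone sketch of that single-save guard (the class and names here are illustrative, not the actual Trainer implementation):

    import atexit

    class EmergencyCheckpointer:
        """Illustrative guard: at most one emergency save per process."""
        def __init__(self):
            self._emergency_save_completed_this_run = False
            self._emergency_save_running = False
            atexit.register(self._atexit_save)  # fallback trigger on interpreter exit

        def emergency_save(self, reason):
            # Skip if another trigger (e.g. the except block in train()) already saved
            # or a save is currently in progress.
            if self._emergency_save_completed_this_run or self._emergency_save_running:
                return
            self._emergency_save_running = True
            try:
                print(f"writing emergency checkpoint (reason={reason})")
            finally:
                self._emergency_save_completed_this_run = True
                self._emergency_save_running = False

        def _atexit_save(self):
            self.emergency_save("atexit")

    ckpt = EmergencyCheckpointer()
    ckpt.emergency_save("training_exception")  # first trigger performs the save
    # the atexit hook then returns immediately because the completed flag is set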
@ -3306,7 +3314,7 @@ class Trainer:
         minimal_state = {
             "global_step": self.state.global_step,
             "epoch": self.state.epoch,
-            "is_world_process_zero": self.state.is_world_process_zero,  # Useful for sanity check on resume
+            "is_world_process_zero": self.state.is_world_process_zero,
             "is_local_process_zero": self.state.is_local_process_zero,
         }
         with open(minimal_state_path, "w") as f:
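The minimal state is plain JSON on disk, so a resume path only has to read it back and sanity-check it. A hedged sketch of the round trip (the path and literal values are assumptions for illustration; the real fields come from self.state as in the hunk above):

    import json
    from pathlib import Path

    # Hypothetical location and values, standing in for the Trainer's real ones.
    minimal_state_path = Path("emergency_checkpoint") / "trainer_state_minimal.json"
    minimal_state_path.parent.mkdir(parents=True, exist_ok=True)

    minimal_state = {
        "global_step": 1200,
        "epoch": 0.75,
        "is_world_process_zero": True,   # sanity check on resume
        "is_local_process_zero": True,
    }
    with open(minimal_state_path, "w") as f:
        json.dump(minimal_state, f, indent=2)

    # On resume: read it back and verify it was written by the main process.
    with open(minimal_state_path) as f:
        restored = json.load(f)
    assert restored["is_world_process_zero"], "minimal state should come from rank 0"
    print(f"resuming from global_step={restored['global_step']}")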