Reduced the scope of try except block for easier debugging in train method for emergency checkpointing logic.

This commit is contained in:
AyushSharma173 2025-06-30 19:52:23 -05:00
parent 68663be9e1
commit cb9655a06e

View File

@ -2128,6 +2128,9 @@ class Trainer:
return model
def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
@ -2222,32 +2225,37 @@ class Trainer:
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
)
try:
if args.push_to_hub:
try:
# Disable progress bars when uploading models during checkpoints to avoid polluting stdout
hf_hub_utils.disable_progress_bars()
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
finally:
hf_hub_utils.enable_progress_bars()
else:
return inner_training_loop(
def _safe_inner_training_loop(*loop_args, **loop_kwargs):
try:
return inner_training_loop(*loop_args, **loop_kwargs)
except BaseException as e: # BaseException also catches KeyboardInterrupt / OOM
logger.error(f"Caught unexpected exception during training: {e}")
if self.args.enable_emergency_checkpoint:
self._common_emergency_save("training_exception")
raise # re-raise for normal debugging
if args.push_to_hub:
hf_hub_utils.disable_progress_bars()
try:
return _safe_inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
except Exception as e:
# This is where you explicitly call your emergency save logic
logger.error(f"Caught an unexpected exception during training: {e}")
if self.args.enable_emergency_checkpoint:
self._common_emergency_save("training_exception")
raise e # Re-raise the exception to signal that training failed
finally:
hf_hub_utils.enable_progress_bars()
else:
return _safe_inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
def get_tp_size(self) -> int:
"""Get the tensor parallel size from either the model or DeepSpeed config."""
@ -3282,7 +3290,7 @@ class Trainer:
# Check if an emergency save has already completed or is in progress from another trigger
# This prevents the atexit handler from running if the try...except in train() already initiated a save
if self._emergency_save_completed_this_run: # <--- ADD THIS CHECK HERE
if self._emergency_save_completed_this_run:
return
self._emergency_save_running = True
@ -3306,7 +3314,7 @@ class Trainer:
minimal_state = {
"global_step": self.state.global_step,
"epoch": self.state.epoch,
"is_world_process_zero": self.state.is_world_process_zero, # Useful for sanity check on resume
"is_world_process_zero": self.state.is_world_process_zero,
"is_local_process_zero": self.state.is_local_process_zero,
}
with open(minimal_state_path, "w") as f: