diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 05848ffd856..7b72ee767f2 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2128,6 +2128,9 @@ class Trainer:
 
         return model
 
+
+
+
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -2222,32 +2225,37 @@ class Trainer:
         inner_training_loop = find_executable_batch_size(
             self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
         )
-        try:
-            if args.push_to_hub:
-                try:
-                    # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
-                    hf_hub_utils.disable_progress_bars()
-                    return inner_training_loop(
-                        args=args,
-                        resume_from_checkpoint=resume_from_checkpoint,
-                        trial=trial,
-                        ignore_keys_for_eval=ignore_keys_for_eval,
-                    )
-                finally:
-                    hf_hub_utils.enable_progress_bars()
-            else:
-                return inner_training_loop(
+
+        def _safe_inner_training_loop(*loop_args, **loop_kwargs):
+            try:
+                return inner_training_loop(*loop_args, **loop_kwargs)
+            except BaseException as e:  # BaseException also catches KeyboardInterrupt / OOM
+                logger.error(f"Caught unexpected exception during training: {e}")
+                if self.args.enable_emergency_checkpoint:
+                    self._common_emergency_save("training_exception")
+                raise  # re-raise for normal debugging
+
+
+        if args.push_to_hub:
+            hf_hub_utils.disable_progress_bars()
+            try:
+                return _safe_inner_training_loop(
                     args=args,
                     resume_from_checkpoint=resume_from_checkpoint,
                     trial=trial,
                     ignore_keys_for_eval=ignore_keys_for_eval,
                 )
-        except Exception as e:
-            # This is where you explicitly call your emergency save logic
-            logger.error(f"Caught an unexpected exception during training: {e}")
-            if self.args.enable_emergency_checkpoint:
-                self._common_emergency_save("training_exception")
-            raise e  # Re-raise the exception to signal that training failed
+            finally:
+                hf_hub_utils.enable_progress_bars()
+        else:
+            return _safe_inner_training_loop(
+                args=args,
+                resume_from_checkpoint=resume_from_checkpoint,
+                trial=trial,
+                ignore_keys_for_eval=ignore_keys_for_eval,
+            )
+
+
     def get_tp_size(self) -> int:
         """Get the tensor parallel size from either the model or DeepSpeed config."""
@@ -3282,7 +3290,7 @@ class Trainer:
 
         # Check if an emergency save has already completed or is in progress from another trigger
         # This prevents the atexit handler from running if the try...except in train() already initiated a save
-        if self._emergency_save_completed_this_run:  # <--- ADD THIS CHECK HERE
+        if self._emergency_save_completed_this_run:
             return
 
         self._emergency_save_running = True
@@ -3306,7 +3314,7 @@ class Trainer:
         minimal_state = {
             "global_step": self.state.global_step,
             "epoch": self.state.epoch,
-            "is_world_process_zero": self.state.is_world_process_zero,  # Useful for sanity check on resume
+            "is_world_process_zero": self.state.is_world_process_zero,
             "is_local_process_zero": self.state.is_local_process_zero,
         }
         with open(minimal_state_path, "w") as f:
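
For readers skimming the patch, here is a minimal standalone sketch of the pattern the new _safe_inner_training_loop wrapper applies: catch BaseException around the training loop so that crashes, OOMs, and Ctrl+C still trigger a small emergency state dump before the error propagates. The helper name, file name, and arguments below are illustrative only and are not part of the patch.

import json
import logging
import os

logger = logging.getLogger(__name__)


def run_with_emergency_save(training_loop, output_dir, state, enabled=True):
    """Run `training_loop()` and dump a minimal state file if it raises."""
    try:
        return training_loop()
    except BaseException as e:  # BaseException also covers KeyboardInterrupt / SystemExit
        logger.error(f"Caught unexpected exception during training: {e}")
        if enabled:
            os.makedirs(output_dir, exist_ok=True)
            # Mirror of the patch's minimal_state: just enough to record where training stopped.
            minimal_state = {"global_step": state["global_step"], "epoch": state["epoch"]}
            with open(os.path.join(output_dir, "emergency_state.json"), "w") as f:
                json.dump(minimal_state, f)
        raise  # re-raise so the failure still surfaces to the caller

In the patch itself, the equivalent switch is the enable_emergency_checkpoint flag read from self.args, and the actual state writing lives in _common_emergency_save, which is guarded so that the atexit handler and the exception path do not both run it.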