mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Reduced the scope of the try/except block in the train method's emergency checkpointing logic for easier debugging.
This commit is contained in:
parent
68663be9e1
commit
cb9655a06e
@ -2128,6 +2128,9 @@ class Trainer:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
resume_from_checkpoint: Optional[Union[str, bool]] = None,
|
resume_from_checkpoint: Optional[Union[str, bool]] = None,
|
||||||
@ -2222,32 +2225,37 @@ class Trainer:
|
|||||||
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
|
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
if args.push_to_hub:
|
def _safe_inner_training_loop(*loop_args, **loop_kwargs):
|
||||||
try:
|
try:
|
||||||
# Disable progress bars when uploading models during checkpoints to avoid polluting stdout
|
return inner_training_loop(*loop_args, **loop_kwargs)
|
||||||
hf_hub_utils.disable_progress_bars()
|
except BaseException as e: # BaseException also catches KeyboardInterrupt / OOM
|
||||||
return inner_training_loop(
|
logger.error(f"Caught unexpected exception during training: {e}")
|
||||||
args=args,
|
if self.args.enable_emergency_checkpoint:
|
||||||
resume_from_checkpoint=resume_from_checkpoint,
|
self._common_emergency_save("training_exception")
|
||||||
trial=trial,
|
raise # re-raise for normal debugging
|
||||||
ignore_keys_for_eval=ignore_keys_for_eval,
|
|
||||||
)
|
|
||||||
finally:
|
if args.push_to_hub:
|
||||||
hf_hub_utils.enable_progress_bars()
|
hf_hub_utils.disable_progress_bars()
|
||||||
else:
|
try:
|
||||||
return inner_training_loop(
|
return _safe_inner_training_loop(
|
||||||
args=args,
|
args=args,
|
||||||
resume_from_checkpoint=resume_from_checkpoint,
|
resume_from_checkpoint=resume_from_checkpoint,
|
||||||
trial=trial,
|
trial=trial,
|
||||||
ignore_keys_for_eval=ignore_keys_for_eval,
|
ignore_keys_for_eval=ignore_keys_for_eval,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
finally:
|
||||||
# This is where you explicitly call your emergency save logic
|
hf_hub_utils.enable_progress_bars()
|
||||||
logger.error(f"Caught an unexpected exception during training: {e}")
|
else:
|
||||||
if self.args.enable_emergency_checkpoint:
|
return _safe_inner_training_loop(
|
||||||
self._common_emergency_save("training_exception")
|
args=args,
|
||||||
raise e # Re-raise the exception to signal that training failed
|
resume_from_checkpoint=resume_from_checkpoint,
|
||||||
|
trial=trial,
|
||||||
|
ignore_keys_for_eval=ignore_keys_for_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_tp_size(self) -> int:
|
def get_tp_size(self) -> int:
|
||||||
"""Get the tensor parallel size from either the model or DeepSpeed config."""
|
"""Get the tensor parallel size from either the model or DeepSpeed config."""
|
||||||
@ -3282,7 +3290,7 @@ class Trainer:
|
|||||||
|
|
||||||
# Check if an emergency save has already completed or is in progress from another trigger
|
# Check if an emergency save has already completed or is in progress from another trigger
|
||||||
# This prevents the atexit handler from running if the try...except in train() already initiated a save
|
# This prevents the atexit handler from running if the try...except in train() already initiated a save
|
||||||
if self._emergency_save_completed_this_run: # <--- ADD THIS CHECK HERE
|
if self._emergency_save_completed_this_run:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._emergency_save_running = True
|
self._emergency_save_running = True
|
||||||
@ -3306,7 +3314,7 @@ class Trainer:
|
|||||||
minimal_state = {
|
minimal_state = {
|
||||||
"global_step": self.state.global_step,
|
"global_step": self.state.global_step,
|
||||||
"epoch": self.state.epoch,
|
"epoch": self.state.epoch,
|
||||||
"is_world_process_zero": self.state.is_world_process_zero, # Useful for sanity check on resume
|
"is_world_process_zero": self.state.is_world_process_zero,
|
||||||
"is_local_process_zero": self.state.is_local_process_zero,
|
"is_local_process_zero": self.state.is_local_process_zero,
|
||||||
}
|
}
|
||||||
with open(minimal_state_path, "w") as f:
|
with open(minimal_state_path, "w") as f:
|
||||||
|
Loading…
Reference in New Issue
Block a user