fix: prevent second save in the end of training if last step was saved already (#36219)

* fix: prevent second save in the end of training

* fix: prevent second save in the end of training

* test: added test for no duplicate save on epoch save strategy

* fix: removed TrainerControl

* chore: style formatting

---------

Co-authored-by: JaktensTid <jaktenstid1@gmail.com>
This commit is contained in:
Nosimus 2025-02-20 20:38:52 +04:00 committed by GitHub
parent 5412ff1a13
commit effaef334b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 1 deletions

View File

@ -600,7 +600,7 @@ class DefaultFlowCallback(TrainerCallback):
if state.global_step >= state.max_steps:
control.should_training_stop = True
# Save the model at the end if we have a save strategy
if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
if args.save_strategy == SaveStrategy.STEPS:
control.should_save = True
return control

View File

@ -425,3 +425,21 @@ class TrainerCallbackTest(unittest.TestCase):
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
trainer._load_callback_state()
assert trainer.control.should_training_stop
def test_no_duplicate_save_on_epoch_save_strategy(self):
times_saved = 0
class OnEndCallback(TrainerCallback):
def on_step_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
nonlocal times_saved
if control.should_save:
times_saved += 1
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
nonlocal times_saved
if control.should_save:
times_saved += 1
trainer = self.get_trainer(max_steps=2, save_strategy="epoch", callbacks=[OnEndCallback])
trainer.train()
assert times_saved == 1