mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
fix: prevent second save in the end of training if last step was saved already (#36219)
* fix: prevent second save in the end of training * fix: prevent second save in the end of training * test: added test for no duplicate save on epoch save strategy * fix: removed TrainerControl * chore: style formatting --------- Co-authored-by: JaktensTid <jaktenstid1@gmail.com>
This commit is contained in:
parent
5412ff1a13
commit
effaef334b
@ -600,7 +600,7 @@ class DefaultFlowCallback(TrainerCallback):
|
||||
if state.global_step >= state.max_steps:
|
||||
control.should_training_stop = True
|
||||
# Save the model at the end if we have a save strategy
|
||||
if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
|
||||
if args.save_strategy == SaveStrategy.STEPS:
|
||||
control.should_save = True
|
||||
|
||||
return control
|
||||
|
@ -425,3 +425,21 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
|
||||
trainer._load_callback_state()
|
||||
assert trainer.control.should_training_stop
|
||||
|
||||
def test_no_duplicate_save_on_epoch_save_strategy(self):
|
||||
times_saved = 0
|
||||
|
||||
class OnEndCallback(TrainerCallback):
|
||||
def on_step_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
|
||||
nonlocal times_saved
|
||||
if control.should_save:
|
||||
times_saved += 1
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
|
||||
nonlocal times_saved
|
||||
if control.should_save:
|
||||
times_saved += 1
|
||||
|
||||
trainer = self.get_trainer(max_steps=2, save_strategy="epoch", callbacks=[OnEndCallback])
|
||||
trainer.train()
|
||||
assert times_saved == 1
|
||||
|
Loading…
Reference in New Issue
Block a user