diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 45dfcee71bc..a1c98a64aef 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -600,7 +600,7 @@ class DefaultFlowCallback(TrainerCallback):
         if state.global_step >= state.max_steps:
             control.should_training_stop = True
             # Save the model at the end if we have a save strategy
-            if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
+            if args.save_strategy == SaveStrategy.STEPS:
                 control.should_save = True
 
         return control
diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py
index 0d1e6645f9a..996cd1ecb9a 100644
--- a/tests/trainer/test_trainer_callback.py
+++ b/tests/trainer/test_trainer_callback.py
@@ -425,3 +425,21 @@ class TrainerCallbackTest(unittest.TestCase):
         trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
         trainer._load_callback_state()
         assert trainer.control.should_training_stop
+
+    def test_no_duplicate_save_on_epoch_save_strategy(self):
+        times_saved = 0
+
+        class OnEndCallback(TrainerCallback):
+            def on_step_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
+                nonlocal times_saved
+                if control.should_save:
+                    times_saved += 1
+
+            def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
+                nonlocal times_saved
+                if control.should_save:
+                    times_saved += 1
+
+        trainer = self.get_trainer(max_steps=2, save_strategy="epoch", callbacks=[OnEndCallback])
+        trainer.train()
+        assert times_saved == 1
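
For context on why the old condition double-saved: when training stops because `max_steps` is reached on a step that also closes an epoch, `DefaultFlowCallback.on_step_end` set `control.should_save` for every strategy other than `NO`/`BEST`, and the epoch-end flow then requested a save again for `save_strategy="epoch"`, yielding two checkpoints for the same step. Below is a minimal standalone sketch of that control flow; the `SaveStrategy` enum mirrors the names used in the diff, but `saves_for_final_step` is an illustrative stand-in, not the transformers implementation:

```python
# Sketch only: models the two save decisions made on a final step that is
# also an epoch boundary. Not the actual transformers control flow.
from enum import Enum


class SaveStrategy(str, Enum):
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
    BEST = "best"


def saves_for_final_step(save_strategy: SaveStrategy, fixed: bool) -> int:
    """Count how often should_save would fire when max_steps is hit
    on a step that also ends an epoch."""
    saves = 0

    # on_step_end: training stops here. The old code saved for any
    # strategy except NO/BEST; the fixed code saves only for STEPS.
    if fixed:
        if save_strategy == SaveStrategy.STEPS:
            saves += 1
    else:
        if save_strategy not in (SaveStrategy.NO, SaveStrategy.BEST):
            saves += 1

    # on_epoch_end: the epoch-based save fires independently.
    if save_strategy == SaveStrategy.EPOCH:
        saves += 1

    return saves


assert saves_for_final_step(SaveStrategy.EPOCH, fixed=False) == 2  # duplicate save
assert saves_for_final_step(SaveStrategy.EPOCH, fixed=True) == 1   # fixed
```

The new test exercises exactly this scenario end to end: with `max_steps=2` and `save_strategy="epoch"`, it counts how many times `control.should_save` is observed across `on_step_end` and `on_epoch_end`, and asserts a single save.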