mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00
fix: prevent second save in the end of training if last step was saved already (#36219)
* fix: prevent second save in the end of training * fix: prevent second save in the end of training * test: added test for no duplicate save on epoch save strategy * fix: removed TrainerControl * chore: style formatting --------- Co-authored-by: JaktensTid <jaktenstid1@gmail.com>
This commit is contained in:
parent
5412ff1a13
commit
effaef334b
@ -600,7 +600,7 @@ class DefaultFlowCallback(TrainerCallback):
|
|||||||
if state.global_step >= state.max_steps:
|
if state.global_step >= state.max_steps:
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
# Save the model at the end if we have a save strategy
|
# Save the model at the end if we have a save strategy
|
||||||
if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
|
if args.save_strategy == SaveStrategy.STEPS:
|
||||||
control.should_save = True
|
control.should_save = True
|
||||||
|
|
||||||
return control
|
return control
|
||||||
|
@ -425,3 +425,21 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
|
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
|
||||||
trainer._load_callback_state()
|
trainer._load_callback_state()
|
||||||
assert trainer.control.should_training_stop
|
assert trainer.control.should_training_stop
|
||||||
|
|
||||||
|
def test_no_duplicate_save_on_epoch_save_strategy(self):
|
||||||
|
times_saved = 0
|
||||||
|
|
||||||
|
class OnEndCallback(TrainerCallback):
|
||||||
|
def on_step_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
|
||||||
|
nonlocal times_saved
|
||||||
|
if control.should_save:
|
||||||
|
times_saved += 1
|
||||||
|
|
||||||
|
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
|
||||||
|
nonlocal times_saved
|
||||||
|
if control.should_save:
|
||||||
|
times_saved += 1
|
||||||
|
|
||||||
|
trainer = self.get_trainer(max_steps=2, save_strategy="epoch", callbacks=[OnEndCallback])
|
||||||
|
trainer.train()
|
||||||
|
assert times_saved == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user