diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 7b711f65701..8f241a9db4a 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -707,10 +707,14 @@ class EarlyStoppingCallback(TrainerCallback, ExportableState): self.early_stopping_patience_counter += 1 def on_train_begin(self, args, state, control, **kwargs): - assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True" + if not args.load_best_model_at_end: + logger.warning( + "Using EarlyStoppingCallback without load_best_model_at_end=True. " + "Once training is finished, the best model will not be loaded automatically." + ) assert ( args.metric_for_best_model is not None - ), "EarlyStoppingCallback requires metric_for_best_model is defined" + ), "EarlyStoppingCallback requires metric_for_best_model to be defined" assert ( args.eval_strategy != IntervalStrategy.NO ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1703cb40098..d89c4aa8030 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3484,6 +3484,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): except AssertionError: self.assertEqual(trainer.state.global_step, 0) + # even if load_best_model_at_end is False, `best_model_checkpoint` should be set + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer( + output_dir=tmp_dir, + num_train_epochs=20, + gradient_accumulation_steps=1, + per_device_train_batch_size=16, + load_best_model_at_end=False, + eval_strategy=IntervalStrategy.EPOCH, + save_strategy=IntervalStrategy.EPOCH, + compute_metrics=AlmostAccuracy(), + metric_for_best_model="accuracy", + ) + trainer.add_callback(EarlyStoppingCallback(1, 0.0001)) + train_output = trainer.train() + self.assertIsNotNone(trainer.state.best_model_checkpoint) + def test_flos_extraction(self): with tempfile.TemporaryDirectory() as tmp_dir: trainer = get_regression_trainer(learning_rate=0.1, output_dir=tmp_dir)