Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Let EarlyStoppingCallback not require load_best_model_at_end (#35101)

* Bookmark

* Add warning
parent 0aaf124fb9
commit b02828e4af
@@ -707,10 +707,14 @@ class EarlyStoppingCallback(TrainerCallback, ExportableState):
             self.early_stopping_patience_counter += 1
 
     def on_train_begin(self, args, state, control, **kwargs):
-        assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
+        if not args.load_best_model_at_end:
+            logger.warning(
+                "Using EarlyStoppingCallback without load_best_model_at_end=True. "
+                "Once training is finished, the best model will not be loaded automatically."
+            )
         assert (
             args.metric_for_best_model is not None
-        ), "EarlyStoppingCallback requires metric_for_best_model is defined"
+        ), "EarlyStoppingCallback requires metric_for_best_model to be defined"
         assert (
             args.eval_strategy != IntervalStrategy.NO
         ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
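With this change, a run that keeps load_best_model_at_end=False no longer fails in on_train_begin; it only emits the warning above, while metric_for_best_model and an evaluation strategy are still asserted. A minimal sketch of the new behaviour (not part of the commit; the output directory and metric name are illustrative placeholders):

from transformers import EarlyStoppingCallback, TrainerControl, TrainerState, TrainingArguments

args = TrainingArguments(
    output_dir="tmp_output",           # placeholder directory
    eval_strategy="epoch",             # evaluation must still be enabled
    metric_for_best_model="accuracy",  # still asserted by the callback
    load_best_model_at_end=False,      # previously an AssertionError, now only a warning
)
callback = EarlyStoppingCallback(early_stopping_patience=1, early_stopping_threshold=0.0001)
# Logs "Using EarlyStoppingCallback without load_best_model_at_end=True. ..." and returns,
# instead of raising as it did before this commit.
callback.on_train_begin(args, TrainerState(), TrainerControl())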
@@ -3484,6 +3484,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             except AssertionError:
                 self.assertEqual(trainer.state.global_step, 0)
 
+        # even if load_best_model_at_end is False, `best_model_checkpoint` should be set
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                num_train_epochs=20,
+                gradient_accumulation_steps=1,
+                per_device_train_batch_size=16,
+                load_best_model_at_end=False,
+                eval_strategy=IntervalStrategy.EPOCH,
+                save_strategy=IntervalStrategy.EPOCH,
+                compute_metrics=AlmostAccuracy(),
+                metric_for_best_model="accuracy",
+            )
+            trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
+            train_output = trainer.train()
+            self.assertIsNotNone(trainer.state.best_model_checkpoint)
+
     def test_flos_extraction(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(learning_rate=0.1, output_dir=tmp_dir)
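Because the best model is no longer reloaded automatically in this configuration, the checkpoint path recorded in trainer.state.best_model_checkpoint (which the new test asserts is still populated) can be reloaded by hand after training. A hedged sketch, assuming an already-trained Trainer named trainer and a sequence-classification model; both are illustrative assumptions, not part of the commit:

from transformers import AutoModelForSequenceClassification

# `trainer` is assumed to be an existing Trainer whose training has finished.
best_ckpt = trainer.state.best_model_checkpoint  # still set even with load_best_model_at_end=False
if best_ckpt is not None:
    model = AutoModelForSequenceClassification.from_pretrained(best_ckpt)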