Enforce eval and save strategies are compatible when --load_best_model_at_end (#12786)
* Enforce eval and save strategies are compatible when --load_best_model_at_end
* Update doc
* Fix typos
* Fix tests
parent 546dc24e08
commit 0118ef89ee
src/transformers/trainer_callback.py:

@@ -418,8 +418,7 @@ class DefaultFlowCallback(TrainerCallback):
 
         # Save
         if (
-            not args.load_best_model_at_end
-            and args.save_strategy == IntervalStrategy.STEPS
+            args.save_strategy == IntervalStrategy.STEPS
             and args.save_steps > 0
             and state.global_step % args.save_steps == 0
         ):

@@ -439,8 +438,6 @@ class DefaultFlowCallback(TrainerCallback):
 
         # Evaluate
         if args.evaluation_strategy == IntervalStrategy.EPOCH:
             control.should_evaluate = True
-            if args.load_best_model_at_end:
-                control.should_save = True
 
         # Save
         if args.save_strategy == IntervalStrategy.EPOCH:
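Both hunks remove the special-casing of load_best_model_at_end: a checkpoint is no longer forced after every evaluation, and saving now follows save_strategy alone. A minimal sketch of the resulting step-end schedule (not the library's code; eval_steps=5 and save_steps=10 are hypothetical values chosen so that save_steps is a round multiple of eval_steps):

eval_steps, save_steps = 5, 10  # hypothetical, compatible settings

for global_step in range(1, 31):
    should_evaluate = global_step % eval_steps == 0
    should_save = global_step % save_steps == 0
    if should_save:
        # Because save_steps % eval_steps == 0, every save step is also an
        # eval step, so each checkpoint comes with a metric that the
        # best-model tracking can use.
        assert should_evaluate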
src/transformers/training_args.py:

@@ -172,8 +172,7 @@ class TrainingArguments:
 
         logging_steps (:obj:`int`, `optional`, defaults to 500):
             Number of update steps between two logs if :obj:`logging_strategy="steps"`.
         save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
-            The checkpoint save strategy to adopt during training (Note that when :obj:`load_best_model_at_end=True`,
-            this parameter is ignored and the model is saved after each evaluation). Possible values are:
+            The checkpoint save strategy to adopt during training. Possible values are:
 
             * :obj:`"no"`: No save is done during training.
             * :obj:`"epoch"`: Save is done at the end of each epoch.

@@ -247,8 +246,9 @@ class TrainingArguments:
 
             .. note::
 
-                When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and
-                the model will be saved after each evaluation.
+                When set to :obj:`True`, the parameter :obj:`save_strategy` needs to be the same as
+                :obj:`eval_strategy`, and in the case it is "steps", :obj:`save_steps` must be a round multiple of
+                :obj:`eval_steps`.
         metric_for_best_model (:obj:`str`, `optional`):
             Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
             models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
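The new contract in practice (a hedged sketch; the output_dir value is a placeholder): strategies must match, and with the "steps" strategy save_steps must be a round multiple of eval_steps.

from transformers import TrainingArguments

# Valid under the new rule: both strategies are "steps" (the default for
# save_strategy) and save_steps (10) is a round multiple of eval_steps (5).
args = TrainingArguments(
    output_dir="out",  # placeholder path
    evaluation_strategy="steps",
    eval_steps=5,
    save_steps=10,
    load_best_model_at_end=True,
)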
@@ -666,6 +666,19 @@ class TrainingArguments:
 
         if self.eval_steps is None:
             self.eval_steps = self.logging_steps
 
+        # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
+        if self.load_best_model_at_end:
+            if self.evaluation_strategy != self.save_strategy:
+                raise ValueError(
+                    "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
+                    f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}"
+                )
+            if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
+                raise ValueError(
+                    "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation "
+                    f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
+                )
+
         if self.load_best_model_at_end and self.metric_for_best_model is None:
             self.metric_for_best_model = "loss"
         if self.greater_is_better is None and self.metric_for_best_model is not None:
tests/test_trainer.py:

@@ -915,6 +915,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 learning_rate=0.1,
                 eval_steps=5,
                 evaluation_strategy="steps",
+                save_steps=5,
                 load_best_model_at_end=True,
             )
             self.assertFalse(trainer.args.greater_is_better)

@@ -930,6 +931,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 learning_rate=0.1,
                 eval_steps=5,
                 evaluation_strategy="steps",
+                save_steps=5,
                 load_best_model_at_end=True,
                 metric_for_best_model="accuracy",
                 compute_metrics=AlmostAccuracy(),

@@ -939,7 +941,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             self.check_saved_checkpoints(tmpdir, 5, total)
             self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True)
 
-        # Save is done every eval regardless of the strategy
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(
                 a=1.5,

@@ -947,6 +948,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 output_dir=tmpdir,
                 learning_rate=0.1,
                 evaluation_strategy="epoch",
+                save_strategy="epoch",
                 load_best_model_at_end=True,
                 metric_for_best_model="accuracy",
                 compute_metrics=AlmostAccuracy(),

@@ -965,6 +967,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 learning_rate=0.1,
                 eval_steps=5,
                 evaluation_strategy="steps",
+                save_steps=5,
                 load_best_model_at_end=True,
                 pretrained=False,
             )

@@ -1083,6 +1086,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             per_device_train_batch_size=16,
             load_best_model_at_end=True,
             evaluation_strategy=IntervalStrategy.EPOCH,
+            save_strategy=IntervalStrategy.EPOCH,
             compute_metrics=AlmostAccuracy(),
             metric_for_best_model="accuracy",
         )

@@ -1140,13 +1144,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             self.check_checkpoint_deletion(trainer, tmp_dir, [20, 25])
 
             # With best model at end
-            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=2)
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir, evaluation_strategy="steps", load_best_model_at_end=True, save_total_limit=2
+            )
             trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
             self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
 
             # Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True to be able to resume
             # from checkpoint
-            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=1)
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir, evaluation_strategy="steps", load_best_model_at_end=True, save_total_limit=1
+            )
             trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25")
             self.check_checkpoint_deletion(trainer, tmp_dir, [25])
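The two assertions above pin down the retention rule (a hedged sketch of the expected behavior, not the Trainer's internal code): the best checkpoint is protected from rotation, and the remaining save_total_limit budget keeps the most recent checkpoints. When save_total_limit=1 and the best checkpoint is not the latest, the effective limit is treated as 2 so training can still resume from the newest checkpoint, which is the edge case the test comment calls out.

checkpoints = [5, 10, 15, 20, 25]  # steps with saved checkpoints
best = 5                           # mirrors trainer.state.best_model_checkpoint
save_total_limit = 2               # sketch covers the limit >= 2 case only

# The best checkpoint is never deleted; the rest of the budget keeps the
# most recent checkpoints, matching the [5, 25] expectation above.
kept = sorted({best, *checkpoints[-(save_total_limit - 1):]})
assert kept == [5, 25]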
@@ -1350,6 +1358,7 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
             learning_rate=0.1,
             logging_steps=1,
             evaluation_strategy=IntervalStrategy.EPOCH,
+            save_strategy=IntervalStrategy.EPOCH,
             num_train_epochs=4,
             disable_tqdm=True,
             load_best_model_at_end=True,

@@ -1400,6 +1409,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
             learning_rate=0.1,
             logging_steps=1,
             evaluation_strategy=IntervalStrategy.EPOCH,
+            save_strategy=IntervalStrategy.EPOCH,
             num_train_epochs=4,
             disable_tqdm=True,
             load_best_model_at_end=True,