Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
New option called "best" for args.save_strategy (#31817)
* Add _determine_best_metric and new saving logic.

  1. The logic for determining the best metric was separated out from `_save_checkpoint`.
  2. In `_maybe_log_save_evaluate`, whether or not a new best metric was achieved is determined after each evaluation, and if the save strategy is "best" then the TrainerControl is updated accordingly.

* Added SaveStrategy.

  Same as IntervalStrategy, but with a new attribute called BEST.

* IntervalStrategy -> SaveStrategy

* IntervalStrategy -> SaveStrategy for save_strategy.

* Interval -> Save in docstring.

* Updated docstring for save_strategy.

* Added SaveStrategy and made corresponding changes.

  `save_strategy` previously followed `IntervalStrategy` but now follows `SaveStrategy`. The code and the docstring were updated accordingly.

* Changes from `make fixup`.

* Removed redundant metrics argument.

* Added new test_save_best_checkpoint test.

  1. Checks both the case where `metric_for_best_model` is explicitly provided and the case where it is not.
  2. The first case should save two checkpoints, whereas the second should save three.

* Changed should_training_end saving logic.

  The Trainer saves a checkpoint at the end of training by default as long as `save_strategy != SaveStrategy.NO`. This condition was modified to also exclude `SaveStrategy.BEST`, because it would be counterintuitive to ask for only the best checkpoint and still have the last one saved as well.

* `args.metric_for_best_model` defaults to loss.

* Undo metric_for_best_model update.

* Remove checking metric_for_best_model.

* Added test cases for loss and no metric.

* Added error for metric and changed default best_metric.

* Removed unused import.

* `new_best_metric` -> `is_new_best_metric`

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Applied `is_new_best_metric` to all.

  Changes were made for consistency and also to fix a potential bug.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
parent: 8b3b9b48fc
commit: c1753436db
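Before the diff itself, a minimal usage sketch of the new option (model, dataset, and metric name are placeholders; `save_strategy="best"` only makes sense with periodic evaluation and a named metric, otherwise `Trainer.__init__` raises a ValueError, as the diff below shows):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    eval_strategy="epoch",             # "best" needs periodic evaluation
    save_strategy="best",              # the option introduced by this commit
    metric_for_best_model="accuracy",  # required with save_strategy="best"
    greater_is_better=True,            # higher accuracy = better
)
```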
src/transformers/trainer.py
@@ -117,9 +117,9 @@ from .trainer_utils import (
     EvalPrediction,
     HPSearchBackend,
     HubStrategy,
-    IntervalStrategy,
     PredictionOutput,
     RemoveColumnsCollator,
+    SaveStrategy,
     TrainerMemoryTracker,
     TrainOutput,
     check_target_module_exists,
@@ -419,6 +419,12 @@ class Trainer:
             raise ValueError(
                 f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
             )
+        if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
+            if args.metric_for_best_model is None:
+                raise ValueError(
+                    "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
+                )
+
         self.args = args
         self.compute_loss_func = compute_loss_func
         # Seed must be set before instantiating the model when using model
@@ -2998,9 +3004,13 @@ class Trainer:
         metrics = None
         if self.control.should_evaluate:
             metrics = self._evaluate(trial, ignore_keys_for_eval)
+            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+            if self.args.save_strategy == SaveStrategy.BEST:
+                self.control.should_save = is_new_best_metric
 
         if self.control.should_save:
-            self._save_checkpoint(model, trial, metrics=metrics)
+            self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
     def _load_rng_state(self, checkpoint):
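To make the new control flow concrete, a toy restatement (the class and variable names here are ours, not the library's): under "best", `should_save` is driven entirely by whether the last evaluation produced a new best metric.

```python
class Control:
    should_evaluate = True
    should_save = False

control = Control()
save_strategy = "best"
is_new_best_metric = False  # suppose this evaluation did not improve the metric

if control.should_evaluate:
    if save_strategy == "best":
        control.should_save = is_new_best_metric  # overrides any scheduled save

assert control.should_save is False  # no improvement, no checkpoint
```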
@@ -3077,7 +3087,48 @@ class Trainer:
                     "\nThis won't yield the same results as if the training had not been interrupted."
                 )
 
-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _determine_best_metric(self, metrics, trial):
+        """
+        Determine if the model should be saved based on the evaluation metrics.
+        If args.metric_for_best_model is not set, the loss is used.
+
+        Returns:
+            bool: True if a new best metric was found, else False
+        """
+        is_new_best_metric = False
+
+        if self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+
+            try:
+                metric_value = metrics[metric_to_check]
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+
+            operator = np.greater if self.args.greater_is_better else np.less
+
+            if self.state.best_metric is None:
+                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+            if operator(metric_value, self.state.best_metric):
+                run_dir = self._get_output_dir(trial=trial)
+                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+                output_dir = os.path.join(run_dir, checkpoint_folder)
+
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+                is_new_best_metric = True
+
+        return is_new_best_metric
+
+    def _save_checkpoint(self, model, trial):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
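A standalone sketch of the comparison at the heart of `_determine_best_metric`, restated outside the Trainer for clarity (the function name is ours): seeding `best` with -inf/+inf makes the first evaluation always register as a new best.

```python
from typing import Optional

import numpy as np


def is_new_best(metric_value: float, best: Optional[float], greater_is_better: bool) -> bool:
    # Choose the comparison direction, then seed so the first eval always wins.
    op = np.greater if greater_is_better else np.less
    if best is None:
        best = float("-inf") if greater_is_better else float("inf")
    return bool(op(metric_value, best))


assert is_new_best(0.65, None, True)      # first eval always counts as new best
assert not is_new_best(0.64, 0.65, True)  # accuracy dropped -> no save
assert is_new_best(0.02, 0.03, False)     # loss decreased -> save
```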
@@ -3098,31 +3149,6 @@ class Trainer:
         # Save RNG state
         self._save_rng_state(output_dir)
 
-        # Determine the new best metric / best model checkpoint
-        if metrics is not None and self.args.metric_for_best_model is not None:
-            metric_to_check = self.args.metric_for_best_model
-            if not metric_to_check.startswith("eval_"):
-                metric_to_check = f"eval_{metric_to_check}"
-            try:
-                metric_value = metrics[metric_to_check]
-            except KeyError as exc:
-                raise KeyError(
-                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
-                    f"which is not found in the evaluation metrics. "
-                    f"The available evaluation metrics are: {list(metrics.keys())}. "
-                    f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
-                    f"consider changing the `metric_for_best_model` via the TrainingArguments."
-                ) from exc
-
-            operator = np.greater if self.args.greater_is_better else np.less
-            if (
-                self.state.best_metric is None
-                or self.state.best_model_checkpoint is None
-                or operator(metric_value, self.state.best_metric)
-            ):
-                self.state.best_metric = metric_value
-                self.state.best_model_checkpoint = output_dir
-
         # Save the Trainer state
         if self.args.should_save:
             # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
@@ -4543,7 +4569,7 @@ class Trainer:
         # Same for the training arguments
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
-        if self.args.save_strategy == IntervalStrategy.STEPS:
+        if self.args.save_strategy == SaveStrategy.STEPS:
             commit_message = f"Training in progress, step {self.state.global_step}"
         else:
             commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
src/transformers/trainer_callback.py
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 from tqdm.auto import tqdm
 
-from .trainer_utils import IntervalStrategy, has_length
+from .trainer_utils import IntervalStrategy, SaveStrategy, has_length
 from .training_args import TrainingArguments
 from .utils import logging
@@ -555,7 +555,7 @@ class DefaultFlowCallback(TrainerCallback):
 
         # Save
         if (
-            args.save_strategy == IntervalStrategy.STEPS
+            args.save_strategy == SaveStrategy.STEPS
             and state.save_steps > 0
             and state.global_step % state.save_steps == 0
         ):
@@ -565,7 +565,7 @@ class DefaultFlowCallback(TrainerCallback):
         if state.global_step >= state.max_steps:
             control.should_training_stop = True
             # Save the model at the end if we have a save strategy
-            if args.save_strategy != IntervalStrategy.NO:
+            if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
                 control.should_save = True
 
         return control
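The exclusion above can be restated as a one-liner (the helper name is ours): "best" opts out of the unconditional final save, so only improving evaluations ever produce checkpoints.

```python
from transformers.trainer_utils import SaveStrategy


def should_save_at_end(strategy: SaveStrategy) -> bool:
    # Mirrors the condition in DefaultFlowCallback.on_step_end above.
    return strategy not in [SaveStrategy.NO, SaveStrategy.BEST]


assert should_save_at_end(SaveStrategy.EPOCH)
assert not should_save_at_end(SaveStrategy.BEST)
```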
@@ -580,7 +580,7 @@ class DefaultFlowCallback(TrainerCallback):
             control.should_evaluate = True
 
         # Save
-        if args.save_strategy == IntervalStrategy.EPOCH:
+        if args.save_strategy == SaveStrategy.EPOCH:
             control.should_save = True
 
         return control
src/transformers/trainer_utils.py
@@ -227,6 +227,13 @@ class IntervalStrategy(ExplicitEnum):
     EPOCH = "epoch"
 
 
+class SaveStrategy(ExplicitEnum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+    BEST = "best"
+
+
 class EvaluationStrategy(ExplicitEnum):
     NO = "no"
     STEPS = "steps"
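A quick sanity check of the new enum as defined above, assuming `ExplicitEnum` is transformers' str-backed enum base (so members compare equal to their string values):

```python
from transformers.trainer_utils import IntervalStrategy, SaveStrategy

assert SaveStrategy("best") is SaveStrategy.BEST  # string-to-member coercion
assert SaveStrategy.STEPS == "steps"              # str-backed comparison
assert not hasattr(IntervalStrategy, "BEST")      # BEST exists only on SaveStrategy
```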
src/transformers/training_args.py
@@ -33,6 +33,7 @@ from .trainer_utils import (
     FSDPOption,
     HubStrategy,
     IntervalStrategy,
+    SaveStrategy,
     SchedulerType,
 )
 from .utils import (
@@ -349,12 +350,13 @@ class TrainingArguments:
 
         </Tip>
 
-        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+        save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
             The checkpoint save strategy to adopt during training. Possible values are:
 
                 - `"no"`: No save is done during training.
                 - `"epoch"`: Save is done at the end of each epoch.
                 - `"steps"`: Save is done every `save_steps`.
+                - `"best"`: Save is done whenever a new `best_metric` is achieved.
 
             If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
             very end of training, always.
@@ -962,7 +964,7 @@ class TrainingArguments:
         },
     )
     logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
-    save_strategy: Union[IntervalStrategy, str] = field(
+    save_strategy: Union[SaveStrategy, str] = field(
         default="steps",
         metadata={"help": "The checkpoint save strategy to use."},
     )
@@ -1580,7 +1582,7 @@ class TrainingArguments:
 
         self.eval_strategy = IntervalStrategy(self.eval_strategy)
         self.logging_strategy = IntervalStrategy(self.logging_strategy)
-        self.save_strategy = IntervalStrategy(self.save_strategy)
+        self.save_strategy = SaveStrategy(self.save_strategy)
         self.hub_strategy = HubStrategy(self.hub_strategy)
 
         self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
@@ -1616,7 +1618,7 @@ class TrainingArguments:
             if self.eval_steps != int(self.eval_steps):
                 raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
             self.eval_steps = int(self.eval_steps)
-        if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1:
+        if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1:
             if self.save_steps != int(self.save_steps):
                 raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
             self.save_steps = int(self.save_steps)
@@ -2750,8 +2752,8 @@ class TrainingArguments:
         100
         ```
         """
-        self.save_strategy = IntervalStrategy(strategy)
-        if self.save_strategy == IntervalStrategy.STEPS and steps == 0:
+        self.save_strategy = SaveStrategy(strategy)
+        if self.save_strategy == SaveStrategy.STEPS and steps == 0:
             raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
         self.save_steps = steps
         self.save_total_limit = total_limit
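The method patched in this hunk is `TrainingArguments.set_save`, so the new strategy is accepted there as well. A short sketch, mirroring the docstring example visible in the hunk but swapping in "best":

```python
from transformers import TrainingArguments

args = TrainingArguments("working_dir")
args = args.set_save(strategy="best")
assert args.save_strategy == "best"  # ExplicitEnum members compare equal to their str values
```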
src/transformers/training_args_tf.py
@@ -114,7 +114,7 @@ class TFTrainingArguments(TrainingArguments):
             Whether to log and evaluate the first `global_step` or not.
         logging_steps (`int`, *optional*, defaults to 500):
             Number of update steps between two logs if `logging_strategy="steps"`.
-        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+        save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
             The checkpoint save strategy to adopt during training. Possible values are:
 
                 - `"no"`: No save is done during training.
tests/trainer/test_trainer.py
@@ -4041,6 +4041,89 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
         )
 
+    def test_save_best_checkpoint(self):
+        freq = int(64 / self.batch_size)
+        total = int(self.n_epochs * 64 / self.batch_size)
+
+        # Case 1: args.metric_for_best_model == "accuracy".
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                a=1.5,
+                b=2.5,
+                output_dir=tmpdir,
+                learning_rate=0.1,
+                eval_strategy="epoch",
+                save_strategy="best",
+                metric_for_best_model="accuracy",
+                compute_metrics=AlmostAccuracy(),
+            )
+            self.assertTrue(trainer.args.metric_for_best_model == "accuracy")
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=[
+                    {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
+                    {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
+                    {"eval_loss": 0.01, "eval_accuracy": 0.64, "epoch": 3.0},
+                ],
+            ):
+                trainer.train()
+
+            self.assertEqual(len(os.listdir(tmpdir)), 2)
+            self.check_saved_checkpoints(
+                output_dir=tmpdir,
+                freq=freq,
+                total=total,
+            )
+
+        # Case 2: args.metric_for_best_model == "loss".
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                a=1.5,
+                b=2.5,
+                output_dir=tmpdir,
+                learning_rate=0.1,
+                eval_strategy="epoch",
+                save_strategy="best",
+                metric_for_best_model="loss",
+                compute_metrics=AlmostAccuracy(),
+            )
+            self.assertTrue(trainer.args.metric_for_best_model == "loss")
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=[
+                    {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
+                    {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
+                    {"eval_loss": 0.03, "eval_accuracy": 0.66, "epoch": 3.0},
+                ],
+            ):
+                trainer.train()
+
+            self.assertEqual(len(os.listdir(tmpdir)), 2)
+            self.check_saved_checkpoints(
+                output_dir=tmpdir,
+                freq=freq,
+                total=total,
+            )
+
+        # Case 3: Metric name not provided; throw error.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with self.assertRaises(ValueError) as context:
+                trainer = get_regression_trainer(
+                    a=1.5,
+                    b=2.5,
+                    output_dir=tmpdir,
+                    learning_rate=0.1,
+                    eval_strategy="epoch",
+                    save_strategy="best",
+                    compute_metrics=AlmostAccuracy(),
+                )
+
+            self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))
+
 
 @require_torch
 @is_staging_test
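Why Case 1 expects exactly two checkpoints: accuracy improves at epochs 1 and 2 (0.60, 0.65) but regresses at epoch 3 (0.64), so only two evaluations count as new bests. A quick simulation of that count:

```python
accuracies = [0.60, 0.65, 0.64]  # the mocked eval_accuracy values from Case 1
best, saves = float("-inf"), 0
for acc in accuracies:
    if acc > best:       # same direction as greater_is_better=True
        best, saves = acc, saves + 1
assert saves == 2        # matches len(os.listdir(tmpdir)) == 2 in the test
```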