Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)

Update doc for `metric_for_best_model` when `save_strategy="best"` (#35389)

* Updated docstring for `_determine_best_metric`.
* Updated docstring for `metric_for_best_model`.
* Added test case for save strategy.
* Updated incorrect test case.
* Changed `eval_strategy` to match `save_strategy`.
* Separated test cases for metric.
* Allow `load_best_model` when `save_strategy == "best"`.
* Updated docstring for `metric_for_best_model`.
This commit is contained in:
parent 29e74b7cbc
commit 88e18b3c63
@@ -3156,7 +3156,6 @@ class Trainer:
     def _determine_best_metric(self, metrics, trial):
         """
         Determine if the model should be saved based on the evaluation metrics.
-        If args.metric_for_best_model is not set, the loss is used.

         Returns:
             bool: True if a new best metric was found, else False
@@ -476,11 +476,13 @@ class TrainingArguments:

         metric_for_best_model (`str`, *optional*):
             Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different
-            models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`. Will
-            default to `"loss"` if unspecified and `load_best_model_at_end=True` (to use the evaluation loss).
+            models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`.

-            If you set this value, `greater_is_better` will default to `True`. Don't forget to set it to `False` if
-            your metric is better when lower.
+            If not specified, this will default to `"loss"` when either `load_best_model_at_end == True`
+            or `lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU` (to use the evaluation loss).
+
+            If you set this value, `greater_is_better` will default to `True` unless the name ends with "loss".
+            Don't forget to set it to `False` if your metric is better when lower.
         greater_is_better (`bool`, *optional*):
             Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models
             should have a greater metric or not. Will default to:
@@ -1636,7 +1638,7 @@ class TrainingArguments:
             self.save_steps = int(self.save_steps)

         # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
-        if self.load_best_model_at_end:
+        if self.load_best_model_at_end and self.save_strategy != SaveStrategy.BEST:
             if self.eval_strategy != self.save_strategy:
                 raise ValueError(
                     "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
@@ -4220,7 +4220,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 total=total,
             )

-    # Case 3: Metric name not provided; throw error.
     def test_metric_for_best_model_behavior(self):
+        # Case 1: Metric name not provided when `save_strategy == "best"`.
+        # Should raise ValueError.
         with tempfile.TemporaryDirectory() as tmpdir:
             with self.assertRaises(ValueError) as context:
                 trainer = get_regression_trainer(
@@ -4232,9 +4234,22 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                     save_strategy="best",
                     compute_metrics=AlmostAccuracy(),
                 )

         self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))

+        # Case 2: Metric name not provided when `load_best_model_at_end == True`.
+        # `metric_for_best_model` should be set to `"loss"` by default.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                a=1.5,
+                b=2.5,
+                output_dir=tmpdir,
+                learning_rate=0.1,
+                eval_strategy="steps",
+                save_strategy="steps",
+                load_best_model_at_end=True,
+            )
+            self.assertTrue(trainer.args.metric_for_best_model == "loss")


 @require_torch
 @is_staging_test
(Page-capture residue from the GitHub UI — "Loading…", "Reference in New Issue", "Block a user" — not part of the commit content.)