move code to Trainer.evaluate to enable use of that function with multiple datasets (#27844)

* move code to Trainer.evaluate to enable use of that function with multiple datasets

* test

* update doc string

* and a tip

* forgot the type

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
This commit is contained in:
peter-sk 2023-12-20 10:55:56 +01:00 committed by GitHub
parent cd9f9d63f1
commit 769a9542de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 15 deletions

View File

@ -2261,17 +2261,7 @@ class Trainer:
metrics = None
if self.control.should_evaluate:
if isinstance(self.eval_dataset, dict):
metrics = {}
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
)
metrics.update(dataset_metrics)
else:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)
# Run delayed LR scheduler now that metrics are populated
@ -2997,7 +2987,7 @@ class Trainer:
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
@ -3010,10 +3000,24 @@ class Trainer:
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (`Dataset`, *optional*):
eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]], *optional*):
Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
`__len__` method.
<Tip>
If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
separate evaluations on each dataset. This can be useful to monitor how training affects other
datasets or simply to get a more fine-grained evaluation.
When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
`data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the
loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.
</Tip>
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
@ -3025,6 +3029,19 @@ class Trainer:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
# handle multiple eval datasets
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
if isinstance(eval_dataset, dict):
metrics = {}
for eval_dataset_name, _eval_dataset in eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=_eval_dataset,
ignore_keys=ignore_keys,
metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
)
metrics.update(dataset_metrics)
return metrics
# memory metrics - must set up as early as possible
self._memory_tracker.start()

View File

@ -103,6 +103,7 @@ if is_torch_available():
import transformers.optimization
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
EarlyStoppingCallback,
GlueDataset,
@ -1845,6 +1846,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
result = trainer.evaluate()
self.assertLess(result["eval_loss"], 0.2)
@slow
def test_trainer_eval_multiple(self):
MODEL_ID = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
dataset = LineByLineTextDataset(
tokenizer=tokenizer,
file_path=PATH_SAMPLE_TEXT,
block_size=tokenizer.max_len_single_sentence,
)
for example in dataset.examples:
example["labels"] = example["input_ids"]
training_args = TrainingArguments(
output_dir="./examples",
use_cpu=True,
per_device_eval_batch_size=1,
)
trainer = Trainer(
model=model,
args=training_args,
eval_dataset={
"data1": dataset,
"data2": dataset,
},
)
result = trainer.evaluate()
self.assertIn("eval_data1_loss", result)
self.assertIn("eval_data2_loss", result)
@slow
def test_trainer_eval_lm(self):
MODEL_ID = "distilroberta-base"