Switch metrics in run_ner to datasets (#9567)

* Switch metrics in run_ner to datasets

* Add flag to return all metrics

* Upstream (and rename) sortish_sampler

* Revert "Upstream (and rename) sortish_sampler"

This reverts commit e07d0dcf65.
Sylvain Gugger authored 2021-01-14 03:37:07 -05:00, committed by GitHub
parent 5e1bea4f16
commit 46ed56cfd1
2 changed files with 26 additions and 9 deletions

@@ -184,7 +184,7 @@ class ExamplesTests(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             result = run_ner.main()
-            self.assertGreaterEqual(result["eval_accuracy_score"], 0.75)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
             self.assertGreaterEqual(result["eval_precision"], 0.75)
             self.assertLess(result["eval_loss"], 0.5)
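
Note on the renamed key: compute_metrics now reports seqeval's overall accuracy under "accuracy" instead of "accuracy_score", and Trainer prefixes every metric returned from evaluation with "eval_", which is why the test now reads result["eval_accuracy"]. A minimal sketch of that prefixing (the metric values here are made up, not from this commit):

    # Trainer.evaluate() prefixes each key returned by compute_metrics
    # with "eval_" before logging/returning it.
    metrics = {"precision": 0.81, "recall": 0.79, "f1": 0.80, "accuracy": 0.83}
    eval_metrics = {f"eval_{k}": v for k, v in metrics.items()}
    # -> {"eval_precision": 0.81, ..., "eval_accuracy": 0.83}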

@@ -25,8 +25,7 @@ from dataclasses import dataclass, field
 from typing import Optional

 import numpy as np
-from datasets import ClassLabel, load_dataset
-from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
+from datasets import ClassLabel, load_dataset, load_metric

 import transformers
 from transformers import (
@@ -124,6 +123,10 @@ class DataTrainingArguments:
             "one (in which case the other tokens will have a padding index)."
         },
     )
+    return_entity_level_metrics: bool = field(
+        default=False,
+        metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
+    )

     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
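
Because run_ner.py parses DataTrainingArguments with HfArgumentParser, the new boolean field surfaces as a --return_entity_level_metrics command-line flag. A hedged parsing sketch (assumes run_ner.py is importable from the working directory; the dataset name is just an illustrative required argument):

    from transformers import HfArgumentParser

    from run_ner import DataTrainingArguments  # the dataclass patched above

    parser = HfArgumentParser(DataTrainingArguments)
    (data_args,) = parser.parse_args_into_dataclasses(
        args=["--dataset_name", "conll2003", "--return_entity_level_metrics"]
    )
    assert data_args.return_entity_level_metrics is True  # defaults to False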
@@ -323,6 +326,8 @@ def main():
     data_collator = DataCollatorForTokenClassification(tokenizer)

     # Metrics
+    metric = load_metric("seqeval")
+
     def compute_metrics(p):
         predictions, labels = p
         predictions = np.argmax(predictions, axis=2)
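
For reference, the "seqeval" metric loaded here returns per-entity-type scores as nested dicts plus overall_* aggregates. A small illustrative sketch (IOB2 tags; the exact numbers depend on the inputs):

    from datasets import load_metric

    metric = load_metric("seqeval")
    results = metric.compute(
        predictions=[["O", "B-PER", "I-PER", "O"]],
        references=[["O", "B-PER", "I-PER", "O"]],
    )
    # results has roughly this shape:
    # {
    #     "PER": {"precision": 1.0, "recall": 1.0, "f1": 1.0, "number": 1},
    #     "overall_precision": 1.0,
    #     "overall_recall": 1.0,
    #     "overall_f1": 1.0,
    #     "overall_accuracy": 1.0,
    # }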
@@ -337,12 +342,24 @@ def main():
             for prediction, label in zip(predictions, labels)
         ]

-        return {
-            "accuracy_score": accuracy_score(true_labels, true_predictions),
-            "precision": precision_score(true_labels, true_predictions),
-            "recall": recall_score(true_labels, true_predictions),
-            "f1": f1_score(true_labels, true_predictions),
-        }
+        results = metric.compute(predictions=true_predictions, references=true_labels)
+        if data_args.return_entity_level_metrics:
+            # Unpack nested dictionaries
+            final_results = {}
+            for key, value in results.items():
+                if isinstance(value, dict):
+                    for n, v in value.items():
+                        final_results[f"{key}_{n}"] = v
+                else:
+                    final_results[key] = value
+            return final_results
+        else:
+            return {
+                "precision": results["overall_precision"],
+                "recall": results["overall_recall"],
+                "f1": results["overall_f1"],
+                "accuracy": results["overall_accuracy"],
+            }

     # Initialize our Trainer
     trainer = Trainer(
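
To make the new flag concrete, here is the unpacking loop from the hunk above applied to a hand-written seqeval-style result (values are illustrative only):

    # Hypothetical nested results in the shape seqeval produces.
    results = {
        "PER": {"precision": 0.90, "recall": 0.85, "f1": 0.87, "number": 20},
        "overall_precision": 0.90,
        "overall_recall": 0.85,
        "overall_f1": 0.87,
        "overall_accuracy": 0.93,
    }

    # Same flattening as in compute_metrics above.
    final_results = {}
    for key, value in results.items():
        if isinstance(value, dict):
            for n, v in value.items():
                final_results[f"{key}_{n}"] = v
        else:
            final_results[key] = value
    # final_results now holds flat keys such as "PER_precision" and "PER_f1"
    # next to the overall_* scores, so each entity type is logged as its own
    # metric during evaluation.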