mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Switch metrics in run_ner to datasets (#9567)
* Switch metrics in run_ner to datasets
* Add flag to return all metrics
* Upstream (and rename) sortish_sampler
* Revert "Upstream (and rename) sortish_sampler"
This reverts commit e07d0dcf65
.
This commit is contained in:
parent
5e1bea4f16
commit
46ed56cfd1
@ -184,7 +184,7 @@ class ExamplesTests(TestCasePlus):
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_ner.main()
|
||||
self.assertGreaterEqual(result["eval_accuracy_score"], 0.75)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||
self.assertGreaterEqual(result["eval_precision"], 0.75)
|
||||
self.assertLess(result["eval_loss"], 0.5)
|
||||
|
||||
|
@ -25,8 +25,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from datasets import ClassLabel, load_dataset
|
||||
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||
from datasets import ClassLabel, load_dataset, load_metric
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@ -124,6 +123,10 @@ class DataTrainingArguments:
|
||||
"one (in which case the other tokens will have a padding index)."
|
||||
},
|
||||
)
|
||||
return_entity_level_metrics: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
@ -323,6 +326,8 @@ def main():
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer)
|
||||
|
||||
# Metrics
|
||||
metric = load_metric("seqeval")
|
||||
|
||||
def compute_metrics(p):
|
||||
predictions, labels = p
|
||||
predictions = np.argmax(predictions, axis=2)
|
||||
@ -337,12 +342,24 @@ def main():
|
||||
for prediction, label in zip(predictions, labels)
|
||||
]
|
||||
|
||||
return {
|
||||
"accuracy_score": accuracy_score(true_labels, true_predictions),
|
||||
"precision": precision_score(true_labels, true_predictions),
|
||||
"recall": recall_score(true_labels, true_predictions),
|
||||
"f1": f1_score(true_labels, true_predictions),
|
||||
}
|
||||
results = metric.compute(predictions=true_predictions, references=true_labels)
|
||||
if data_args.return_entity_level_metrics:
|
||||
# Unpack nested dictionaries
|
||||
final_results = {}
|
||||
for key, value in results.items():
|
||||
if isinstance(value, dict):
|
||||
for n, v in value.items():
|
||||
final_results[f"{key}_{n}"] = v
|
||||
else:
|
||||
final_results[key] = value
|
||||
return final_results
|
||||
else:
|
||||
return {
|
||||
"precision": results["overall_precision"],
|
||||
"recall": results["overall_recall"],
|
||||
"f1": results["overall_f1"],
|
||||
"accuracy": results["overall_accuracy"],
|
||||
}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
|
Loading…
Reference in New Issue
Block a user