Mirror of https://github.com/huggingface/transformers.git
Tidy Pytorch GLUE benchmark example (#23134)
Migration to Evaluate for metrics is not quite complete
parent b0a78091a5
commit b6933d76d2
@@ -486,6 +486,8 @@ def main():
     # Get the metric function
     if data_args.task_name is not None:
         metric = evaluate.load("glue", data_args.task_name)
+    elif is_regression:
+        metric = evaluate.load("mse")
     else:
         metric = evaluate.load("accuracy")

@@ -494,15 +496,10 @@ def main():
     def compute_metrics(p: EvalPrediction):
         preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
         preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
-        if data_args.task_name is not None:
-            result = metric.compute(predictions=preds, references=p.label_ids)
-            if len(result) > 1:
-                result["combined_score"] = np.mean(list(result.values())).item()
-            return result
-        elif is_regression:
-            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
-        else:
-            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+        result = metric.compute(predictions=preds, references=p.label_ids)
+        if len(result) > 1:
+            result["combined_score"] = np.mean(list(result.values())).item()
+        return result

     # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
     # we already did the padding.
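For context, here is a minimal standalone sketch of the consolidated metric path after this commit. It is not part of the diff: the task name "mrpc" and the toy prediction/label arrays are illustrative assumptions, chosen because MRPC returns more than one metric and so exercises the combined_score branch.

import evaluate
import numpy as np

task_name = "mrpc"     # hypothetical GLUE task, for illustration only
is_regression = False

# Same three-way selection as the updated run_glue.py
if task_name is not None:
    metric = evaluate.load("glue", task_name)
elif is_regression:
    metric = evaluate.load("mse")
else:
    metric = evaluate.load("accuracy")

preds = np.array([1, 0, 1, 1])    # toy model predictions
labels = np.array([1, 0, 0, 1])   # toy references

# GLUE tasks such as MRPC report several metrics (accuracy, f1);
# averaging them mirrors the combined_score logic in compute_metrics.
result = metric.compute(predictions=preds, references=labels)
if len(result) > 1:
    result["combined_score"] = np.mean(list(result.values())).item()
print(result)  # e.g. {'accuracy': 0.75, 'f1': 0.8, 'combined_score': 0.775}

Because every branch now yields an Evaluate metric object with the same compute interface, the single result/combined_score block suffices where the old code needed three separate return paths.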