From 20b658607e9dc2b723d3571be16104832e0ef386 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 19 Nov 2020 13:59:30 -0500
Subject: [PATCH] Fix run_ner script (#8664)

* Fix run_ner script

* Pin datasets
---
 examples/requirements.txt                |  2 +-
 examples/token-classification/run_ner.py | 25 +++++++++++++++++-------
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/examples/requirements.txt b/examples/requirements.txt
index 1ce783440f6..cb218847c67 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -13,7 +13,7 @@ streamlit
 elasticsearch
 nltk
 pandas
-datasets
+datasets >= 1.1.3
 fire
 pytest
 conllu
diff --git a/examples/token-classification/run_ner.py b/examples/token-classification/run_ner.py
index 718927f3ebf..b6232bbed07 100644
--- a/examples/token-classification/run_ner.py
+++ b/examples/token-classification/run_ner.py
@@ -15,7 +15,8 @@
 """
 Fine-tuning the library models for token classification.
 """
-# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as comments.
+# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
+# comments.
 
 import logging
 import os
@@ -24,7 +25,7 @@ from dataclasses import dataclass, field
 from typing import Optional
 
 import numpy as np
-from datasets import load_dataset
+from datasets import ClassLabel, load_dataset
 from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
 
 import transformers
@@ -198,12 +199,17 @@ def main():
 
     if training_args.do_train:
         column_names = datasets["train"].column_names
+        features = datasets["train"].features
     else:
         column_names = datasets["validation"].column_names
-    text_column_name = "words" if "words" in column_names else column_names[0]
-    label_column_name = data_args.task_name if data_args.task_name in column_names else column_names[1]
+        features = datasets["validation"].features
+    text_column_name = "tokens" if "tokens" in column_names else column_names[0]
+    label_column_name = (
+        f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
+    )
 
-    # Labeling (this part will be easier when https://github.com/huggingface/datasets/issues/797 is solved)
+    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
+    # unique labels.
     def get_label_list(labels):
         unique_labels = set()
         for label in labels:
@@ -212,8 +218,13 @@ def main():
         label_list.sort()
         return label_list
 
-    label_list = get_label_list(datasets["train"][label_column_name])
-    label_to_id = {l: i for i, l in enumerate(label_list)}
+    if isinstance(features[label_column_name].feature, ClassLabel):
+        label_list = features[label_column_name].feature.names
+        # No need to convert the labels since they are already ints.
+        label_to_id = {i: i for i in range(len(label_list))}
+    else:
+        label_list = get_label_list(datasets["train"][label_column_name])
+        label_to_id = {l: i for i, l in enumerate(label_list)}
     num_labels = len(label_list)
 
     # Load pretrained model and tokenizer
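
A minimal standalone sketch (not part of the patch) of the label-resolution logic the run_ner.py hunks introduce: when the label column is a Sequence of ClassLabel, the class names come from the dataset features and the labels are already stored as ints, so label_to_id is the identity mapping; otherwise the script falls back to scanning the data for unique labels (the get_label_list path). The toy dataset, the column names "tokens"/"ner_tags", and the tag set below are illustrative, not taken from the commit.

from datasets import ClassLabel, Dataset, Features, Sequence, Value

# Toy NER dataset whose label column is a Sequence(ClassLabel), mirroring e.g. conll2003.
features = Features(
    {
        "tokens": Sequence(Value("string")),
        "ner_tags": Sequence(ClassLabel(names=["O", "B-PER", "I-PER"])),
    }
)
dataset = Dataset.from_dict(
    {"tokens": [["John", "lives", "here"]], "ner_tags": [[1, 0, 0]]},  # tags already encoded as ints
    features=features,
)

label_column_name = "ner_tags"
if isinstance(dataset.features[label_column_name].feature, ClassLabel):
    # Labels are already ints; the ClassLabel names give the human-readable classes.
    label_list = dataset.features[label_column_name].feature.names
    label_to_id = {i: i for i in range(len(label_list))}
else:
    # Fallback: collect the unique labels by scanning the data, as get_label_list does.
    label_list = sorted({tag for tags in dataset[label_column_name] for tag in tags})
    label_to_id = {label: i for i, label in enumerate(label_list)}

print(label_list)   # ['O', 'B-PER', 'I-PER']
print(label_to_id)  # {0: 0, 1: 1, 2: 2}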