Fix run_ner script (#8664)

* Fix run_ner script

* Pin datasets
This commit is contained in:
Sylvain Gugger 2020-11-19 13:59:30 -05:00 committed by GitHub
parent ca0109bd68
commit 20b658607e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 8 deletions

View File

@ -13,7 +13,7 @@ streamlit
elasticsearch
nltk
pandas
datasets
datasets >= 1.1.3
fire
pytest
conllu

View File

@ -15,7 +15,8 @@
"""
Fine-tuning the library models for token classification.
"""
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as comments.
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
# comments.
import logging
import os
@ -24,7 +25,7 @@ from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from datasets import load_dataset
from datasets import ClassLabel, load_dataset
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
import transformers
@ -198,12 +199,17 @@ def main():
if training_args.do_train:
column_names = datasets["train"].column_names
features = datasets["train"].features
else:
column_names = datasets["validation"].column_names
text_column_name = "words" if "words" in column_names else column_names[0]
label_column_name = data_args.task_name if data_args.task_name in column_names else column_names[1]
features = datasets["validation"].features
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
label_column_name = (
f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
)
# Labeling (this part will be easier when https://github.com/huggingface/datasets/issues/797 is solved)
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
# unique labels.
def get_label_list(labels):
unique_labels = set()
for label in labels:
@ -212,8 +218,13 @@ def main():
label_list.sort()
return label_list
label_list = get_label_list(datasets["train"][label_column_name])
label_to_id = {l: i for i, l in enumerate(label_list)}
if isinstance(features[label_column_name].feature, ClassLabel):
label_list = features[label_column_name].feature.names
# No need to convert the labels since they are already ints.
label_to_id = {i: i for i in range(len(label_list))}
else:
label_list = get_label_list(datasets["train"][label_column_name])
label_to_id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list)
# Load pretrained model and tokenizer