mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
parent
ca0109bd68
commit
20b658607e
@ -13,7 +13,7 @@ streamlit
|
||||
elasticsearch
|
||||
nltk
|
||||
pandas
|
||||
datasets
|
||||
datasets >= 1.1.3
|
||||
fire
|
||||
pytest
|
||||
conllu
|
||||
|
@ -15,7 +15,8 @@
|
||||
"""
|
||||
Fine-tuning the library models for token classification.
|
||||
"""
|
||||
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as comments.
|
||||
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
|
||||
# comments.
|
||||
|
||||
import logging
|
||||
import os
|
||||
@ -24,7 +25,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from datasets import ClassLabel, load_dataset
|
||||
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||
|
||||
import transformers
|
||||
@ -198,12 +199,17 @@ def main():
|
||||
|
||||
if training_args.do_train:
|
||||
column_names = datasets["train"].column_names
|
||||
features = datasets["train"].features
|
||||
else:
|
||||
column_names = datasets["validation"].column_names
|
||||
text_column_name = "words" if "words" in column_names else column_names[0]
|
||||
label_column_name = data_args.task_name if data_args.task_name in column_names else column_names[1]
|
||||
features = datasets["validation"].features
|
||||
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
|
||||
label_column_name = (
|
||||
f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
|
||||
)
|
||||
|
||||
# Labeling (this part will be easier when https://github.com/huggingface/datasets/issues/797 is solved)
|
||||
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
||||
# unique labels.
|
||||
def get_label_list(labels):
|
||||
unique_labels = set()
|
||||
for label in labels:
|
||||
@ -212,8 +218,13 @@ def main():
|
||||
label_list.sort()
|
||||
return label_list
|
||||
|
||||
label_list = get_label_list(datasets["train"][label_column_name])
|
||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||
if isinstance(features[label_column_name].feature, ClassLabel):
|
||||
label_list = features[label_column_name].feature.names
|
||||
# No need to convert the labels since they are already ints.
|
||||
label_to_id = {i: i for i in range(len(label_list))}
|
||||
else:
|
||||
label_list = get_label_list(datasets["train"][label_column_name])
|
||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
|
Loading…
Reference in New Issue
Block a user