mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
updated data processor and metrics
This commit is contained in:
parent
0b82e3d0d9
commit
b5ec526f85
2
.gitignore
vendored
2
.gitignore
vendored
@ -130,5 +130,5 @@ runs
|
||||
examples/runs
|
||||
|
||||
# data
|
||||
data
|
||||
/data
|
||||
serialization_dir
|
@ -46,7 +46,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
|
||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from pytorch_transformers.preprocessing import (compute_metrics, output_modes, processors, convert_examples_to_glue_features)
|
||||
from pytorch_transformers import glue_compute_metrics as compute_metrics
|
||||
from pytorch_transformers import glue_output_modes as output_modes
|
||||
from pytorch_transformers import glue_processors as processors
|
||||
from pytorch_transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -275,7 +278,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
features = convert_examples_to_glue_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
||||
|
@ -73,3 +73,10 @@ from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, Wa
|
||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
||||
|
||||
from .data import (is_sklearn_available,
|
||||
InputExample, InputFeatures, DataProcessor,
|
||||
glue_output_modes, glue_convert_examples_to_features, glue_processors)
|
||||
|
||||
if is_sklearn_available():
|
||||
from .data import glue_compute_metrics
|
||||
|
6
pytorch_transformers/data/__init__.py
Normal file
6
pytorch_transformers/data/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .processors import (InputExample, InputFeatures, DataProcessor,
|
||||
glue_output_modes, glue_convert_examples_to_features, glue_processors)
|
||||
from .metrics import is_sklearn_available
|
||||
|
||||
if is_sklearn_available():
|
||||
from .metrics import glue_compute_metrics
|
83
pytorch_transformers/data/metrics/__init__.py
Normal file
83
pytorch_transformers/data/metrics/__init__.py
Normal file
@ -0,0 +1,83 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import csv
|
||||
import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||
_has_sklearn = True
|
||||
except e:
|
||||
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
|
||||
_has_sklearn = False
|
||||
|
||||
def is_sklearn_available():
|
||||
return _has_sklearn
|
||||
|
||||
if _has_sklearn:
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def acc_and_f1(preds, labels):
|
||||
acc = simple_accuracy(preds, labels)
|
||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||
return {
|
||||
"acc": acc,
|
||||
"f1": f1,
|
||||
"acc_and_f1": (acc + f1) / 2,
|
||||
}
|
||||
|
||||
|
||||
def pearson_and_spearman(preds, labels):
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
return {
|
||||
"pearson": pearson_corr,
|
||||
"spearmanr": spearman_corr,
|
||||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
|
||||
|
||||
def glue_compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(labels)
|
||||
if task_name == "cola":
|
||||
return {"mcc": matthews_corrcoef(labels, preds)}
|
||||
elif task_name == "sst-2":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mrpc":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "sts-b":
|
||||
return pearson_and_spearman(preds, labels)
|
||||
elif task_name == "qqp":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "mnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mnli-mm":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "qnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "rte":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "wnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise KeyError(task_name)
|
2
pytorch_transformers/data/processors/__init__.py
Normal file
2
pytorch_transformers/data/processors/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .utils import InputExample, InputFeatures, DataProcessor
|
||||
from .glue import output_modes, processors, convert_examples_to_glue_features
|
@ -15,12 +15,50 @@
|
||||
# limitations under the License.
|
||||
""" GLUE processors and helpers """
|
||||
|
||||
from .utils import DataProcessor
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .utils import DataProcessor, InputExample, InputFeatures
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GLUE_TASKS_NUM_LABELS = {
|
||||
"cola": 2,
|
||||
"mnli": 3,
|
||||
"mrpc": 2,
|
||||
"sst-2": 2,
|
||||
"sts-b": 1,
|
||||
"qqp": 2,
|
||||
"qnli": 2,
|
||||
"rte": 2,
|
||||
"wnli": 2,
|
||||
}
|
||||
|
||||
processors = {
|
||||
"cola": ColaProcessor,
|
||||
"mnli": MnliProcessor,
|
||||
"mnli-mm": MnliMismatchedProcessor,
|
||||
"mrpc": MrpcProcessor,
|
||||
"sst-2": Sst2Processor,
|
||||
"sts-b": StsbProcessor,
|
||||
"qqp": QqpProcessor,
|
||||
"qnli": QnliProcessor,
|
||||
"rte": RteProcessor,
|
||||
"wnli": WnliProcessor,
|
||||
}
|
||||
|
||||
output_modes = {
|
||||
"cola": "classification",
|
||||
"mnli": "classification",
|
||||
"mnli-mm": "classification",
|
||||
"mrpc": "classification",
|
||||
"sst-2": "classification",
|
||||
"sts-b": "regression",
|
||||
"qqp": "classification",
|
||||
"qnli": "classification",
|
||||
"rte": "classification",
|
||||
"wnli": "classification",
|
||||
}
|
||||
|
||||
def convert_examples_to_glue_features(examples, label_list, max_seq_length,
|
||||
tokenizer, output_mode,
|
||||
@ -91,37 +129,6 @@ def convert_examples_to_glue_features(examples, label_list, max_seq_length,
|
||||
return features
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
"""A single training/test example for simple sequence classification."""
|
||||
|
||||
def __init__(self, guid, text_a, text_b=None, label=None):
|
||||
"""Constructs a InputExample.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
text_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
self.guid = guid
|
||||
self.text_a = text_a
|
||||
self.text_b = text_b
|
||||
self.label = label
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self, input_ids, input_mask, segment_ids, label_id):
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.label_id = label_id
|
||||
|
||||
|
||||
class MrpcProcessor(DataProcessor):
|
||||
"""Processor for the MRPC data set (GLUE version)."""
|
||||
|
||||
@ -420,15 +427,3 @@ class WnliProcessor(DataProcessor):
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
GLUE_TASKS_NUM_LABELS = {
|
||||
"cola": 2,
|
||||
"mnli": 3,
|
||||
"mrpc": 2,
|
||||
"sst-2": 2,
|
||||
"sts-b": 1,
|
||||
"qqp": 2,
|
||||
"qnli": 2,
|
||||
"rte": 2,
|
||||
"wnli": 2,
|
||||
}
|
@ -17,8 +17,34 @@
|
||||
import csv
|
||||
import sys
|
||||
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||
class InputExample(object):
|
||||
"""A single training/test example for simple sequence classification."""
|
||||
def __init__(self, guid, text_a, text_b=None, label=None):
|
||||
"""Constructs a InputExample.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
text_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
self.guid = guid
|
||||
self.text_a = text_a
|
||||
self.text_b = text_b
|
||||
self.label = label
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self, input_ids, input_mask, segment_ids, label_id):
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.label_id = label_id
|
||||
|
||||
|
||||
class DataProcessor(object):
|
||||
@ -47,53 +73,3 @@ class DataProcessor(object):
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def acc_and_f1(preds, labels):
|
||||
acc = simple_accuracy(preds, labels)
|
||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||
return {
|
||||
"acc": acc,
|
||||
"f1": f1,
|
||||
"acc_and_f1": (acc + f1) / 2,
|
||||
}
|
||||
|
||||
|
||||
def pearson_and_spearman(preds, labels):
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
return {
|
||||
"pearson": pearson_corr,
|
||||
"spearmanr": spearman_corr,
|
||||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
|
||||
|
||||
def compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(labels)
|
||||
if task_name == "cola":
|
||||
return {"mcc": matthews_corrcoef(labels, preds)}
|
||||
elif task_name == "sst-2":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mrpc":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "sts-b":
|
||||
return pearson_and_spearman(preds, labels)
|
||||
elif task_name == "qqp":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "mnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mnli-mm":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "qnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "rte":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "wnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise KeyError(task_name)
|
@ -1,56 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .glue import (ColaProcessor,
|
||||
MnliProcessor,
|
||||
MnliMismatchedProcessor,
|
||||
MrpcProcessor,
|
||||
Sst2Processor,
|
||||
StsbProcessor,
|
||||
QqpProcessor,
|
||||
QnliProcessor,
|
||||
RteProcessor,
|
||||
WnliProcessor,
|
||||
convert_examples_to_glue_features,
|
||||
)
|
||||
|
||||
from .utils import DataProcessor, simple_accuracy, acc_and_f1, pearson_and_spearman, compute_metrics
|
||||
|
||||
processors = {
|
||||
"cola": ColaProcessor,
|
||||
"mnli": MnliProcessor,
|
||||
"mnli-mm": MnliMismatchedProcessor,
|
||||
"mrpc": MrpcProcessor,
|
||||
"sst-2": Sst2Processor,
|
||||
"sts-b": StsbProcessor,
|
||||
"qqp": QqpProcessor,
|
||||
"qnli": QnliProcessor,
|
||||
"rte": RteProcessor,
|
||||
"wnli": WnliProcessor,
|
||||
}
|
||||
|
||||
output_modes = {
|
||||
"cola": "classification",
|
||||
"mnli": "classification",
|
||||
"mnli-mm": "classification",
|
||||
"mrpc": "classification",
|
||||
"sst-2": "classification",
|
||||
"sts-b": "regression",
|
||||
"qqp": "classification",
|
||||
"qnli": "classification",
|
||||
"rte": "classification",
|
||||
"wnli": "classification",
|
||||
}
|
Loading…
Reference in New Issue
Block a user