Merge pull request #1325 from huggingface/glue-included
[Proposal] GLUE processors included in library

Commit: e4022d96f7
.gitignore (vendored, 2 changed lines)

@@ -130,5 +130,5 @@ runs
 examples/runs
 
 # data
-data
+/data
 serialization_dir
examples/run_glue.py

@@ -46,8 +46,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
 
 from pytorch_transformers import AdamW, WarmupLinearSchedule
 
-from utils_glue import (compute_metrics, convert_examples_to_features,
-                        output_modes, processors)
+from pytorch_transformers import glue_compute_metrics as compute_metrics
+from pytorch_transformers import glue_output_modes as output_modes
+from pytorch_transformers import glue_processors as processors
+from pytorch_transformers import glue_convert_examples_to_features as convert_examples_to_features
 
 logger = logging.getLogger(__name__)
 
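
The example script now consumes these helpers from the library's public API instead of a local utils_glue module. A minimal usage sketch (the task key "mrpc" is chosen here purely for illustration):

from pytorch_transformers import glue_processors as processors
from pytorch_transformers import glue_output_modes as output_modes

processor = processors["mrpc"]()      # instantiate the MRPC processor
output_mode = output_modes["mrpc"]    # "classification"
label_list = processor.get_labels()   # ["0", "1"]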
pytorch_transformers/__init__.py

@@ -73,3 +73,10 @@ from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, Wa
 from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
                          cached_path, add_start_docstrings, add_end_docstrings,
                          WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
+
+from .data import (is_sklearn_available,
+                   InputExample, InputFeatures, DataProcessor,
+                   glue_output_modes, glue_convert_examples_to_features, glue_processors, glue_tasks_num_labels)
+
+if is_sklearn_available():
+    from .data import glue_compute_metrics
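
Because glue_compute_metrics depends on scikit-learn, the package root only re-exports it when the dependency is installed. A sketch of the guarded import a caller would write, mirroring the pattern above:

from pytorch_transformers import is_sklearn_available

if is_sklearn_available():
    from pytorch_transformers import glue_compute_metrics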
pytorch_transformers/data/__init__.py (new file, 6 lines)

from .processors import InputExample, InputFeatures, DataProcessor
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features

from .metrics import is_sklearn_available
if is_sklearn_available():
    from .metrics import glue_compute_metrics
pytorch_transformers/data/metrics/__init__.py (new file, 83 lines)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import sys
import logging

logger = logging.getLogger(__name__)

try:
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import matthews_corrcoef, f1_score
    _has_sklearn = True
except (AttributeError, ImportError) as e:
    logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
    _has_sklearn = False

def is_sklearn_available():
    return _has_sklearn

if _has_sklearn:

    def simple_accuracy(preds, labels):
        return (preds == labels).mean()

    def acc_and_f1(preds, labels):
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds)
        return {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
        }

    def pearson_and_spearman(preds, labels):
        pearson_corr = pearsonr(preds, labels)[0]
        spearman_corr = spearmanr(preds, labels)[0]
        return {
            "pearson": pearson_corr,
            "spearmanr": spearman_corr,
            "corr": (pearson_corr + spearman_corr) / 2,
        }

    def glue_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        elif task_name == "sst-2":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mrpc":
            return acc_and_f1(preds, labels)
        elif task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        elif task_name == "qqp":
            return acc_and_f1(preds, labels)
        elif task_name == "mnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mnli-mm":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "qnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "rte":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "wnli":
            return {"acc": simple_accuracy(preds, labels)}
        else:
            raise KeyError(task_name)
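
Usage sketch for the new metrics entry point, assuming scikit-learn is installed; the arrays below are made-up predictions purely for illustration (preds and labels are NumPy arrays of equal length):

import numpy as np
from pytorch_transformers import glue_compute_metrics

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])
print(glue_compute_metrics("mrpc", preds, labels))
# {'acc': 0.75, 'f1': 0.8, 'acc_and_f1': 0.775}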
pytorch_transformers/data/processors/__init__.py (new file, 3 lines)

from .utils import InputExample, InputFeatures, DataProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
pytorch_transformers/data/processors/glue.py (previously utils_glue.py in the examples)

@@ -13,79 +13,81 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" BERT classification fine-tuning: utilities to work with GLUE tasks """
+""" GLUE processors and helpers """
 
 from __future__ import absolute_import, division, print_function
 
 import csv
 import logging
 import os
 import sys
 from io import open
 
-from scipy.stats import pearsonr, spearmanr
-from sklearn.metrics import matthews_corrcoef, f1_score
+from .utils import DataProcessor, InputExample, InputFeatures
 
 logger = logging.getLogger(__name__)
 
+def glue_convert_examples_to_features(examples, label_list, max_seq_length,
+                                      tokenizer, output_mode,
+                                      pad_on_left=False,
+                                      pad_token=0,
+                                      pad_token_segment_id=0,
+                                      mask_padding_with_zero=True):
+    """
+    Loads a data file into a list of `InputBatch`s
+    """
 
-class InputExample(object):
-    """A single training/test example for simple sequence classification."""
+    label_map = {label: i for i, label in enumerate(label_list)}
 
-    def __init__(self, guid, text_a, text_b=None, label=None):
-        """Constructs a InputExample.
+    features = []
+    for (ex_index, example) in enumerate(examples):
+        if ex_index % 10000 == 0:
+            logger.info("Writing example %d of %d" % (ex_index, len(examples)))
 
-        Args:
-            guid: Unique id for the example.
-            text_a: string. The untokenized text of the first sequence. For single
-            sequence tasks, only this sequence must be specified.
-            text_b: (Optional) string. The untokenized text of the second sequence.
-            Only must be specified for sequence pair tasks.
-            label: (Optional) string. The label of the example. This should be
-            specified for train and dev examples, but not for test examples.
-        """
-        self.guid = guid
-        self.text_a = text_a
-        self.text_b = text_b
-        self.label = label
+        inputs = tokenizer.encode_plus(
+            example.text_a,
+            example.text_b,
+            add_special_tokens=True,
+            max_length=max_seq_length,
+            truncate_first_sequence=True  # We're truncating the first sequence as a priority
+        )
+        input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"]
 
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real
+        # tokens are attended to.
+        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
 
-class InputFeatures(object):
-    """A single set of features of data."""
+        # Zero-pad up to the sequence length.
+        padding_length = max_seq_length - len(input_ids)
+        if pad_on_left:
+            input_ids = ([pad_token] * padding_length) + input_ids
+            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+        else:
+            input_ids = input_ids + ([pad_token] * padding_length)
+            input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+            segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
 
-    def __init__(self, input_ids, input_mask, segment_ids, label_id):
-        self.input_ids = input_ids
-        self.input_mask = input_mask
-        self.segment_ids = segment_ids
-        self.label_id = label_id
+        assert len(input_ids) == max_seq_length
+        assert len(input_mask) == max_seq_length
+        assert len(segment_ids) == max_seq_length
 
+        if output_mode == "classification":
+            label_id = label_map[example.label]
+        elif output_mode == "regression":
+            label_id = float(example.label)
+        else:
+            raise KeyError(output_mode)
 
-class DataProcessor(object):
-    """Base class for data converters for sequence classification data sets."""
+        if ex_index < 5:
+            logger.info("*** Example ***")
+            logger.info("guid: %s" % (example.guid))
+            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
+            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            logger.info("label: %s (id = %d)" % (example.label, label_id))
 
-    def get_train_examples(self, data_dir):
-        """Gets a collection of `InputExample`s for the train set."""
-        raise NotImplementedError()
-
-    def get_dev_examples(self, data_dir):
-        """Gets a collection of `InputExample`s for the dev set."""
-        raise NotImplementedError()
-
-    def get_labels(self):
-        """Gets the list of labels for this data set."""
-        raise NotImplementedError()
-
-    @classmethod
-    def _read_tsv(cls, input_file, quotechar=None):
-        """Reads a tab separated value file."""
-        with open(input_file, "r", encoding="utf-8-sig") as f:
-            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
-            lines = []
-            for line in reader:
-                if sys.version_info[0] == 2:
-                    line = list(unicode(cell, 'utf-8') for cell in line)
-                lines.append(line)
-            return lines
+        features.append(
+            InputFeatures(input_ids=input_ids,
+                          input_mask=input_mask,
+                          segment_ids=segment_ids,
+                          label_id=label_id))
+    return features
 
 
 class MrpcProcessor(DataProcessor):
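
End-to-end sketch of the relocated feature converter, assuming a downloaded MRPC folder at the hypothetical path "glue_data/MRPC" (note that truncate_first_sequence was an encode_plus argument at the time of this PR):

from pytorch_transformers import BertTokenizer
from pytorch_transformers import glue_processors as processors
from pytorch_transformers import glue_output_modes as output_modes
from pytorch_transformers import glue_convert_examples_to_features as convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = processors["mrpc"]()
examples = processor.get_dev_examples("glue_data/MRPC")   # hypothetical data path
features = convert_examples_to_features(
    examples, processor.get_labels(), max_seq_length=128,
    tokenizer=tokenizer, output_mode=output_modes["mrpc"])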
@@ -302,7 +304,7 @@ class QnliProcessor(DataProcessor):
     def get_dev_examples(self, data_dir):
         """See base class."""
         return self._create_examples(
-            self._read_tsv(os.path.join(data_dir, "dev.tsv")),
+            self._read_tsv(os.path.join(data_dir, "dev.tsv")),
             "dev_matched")
 
     def get_labels(self):
@@ -387,142 +389,19 @@ class WnliProcessor(DataProcessor):
             InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+glue_tasks_num_labels = {
+    "cola": 2,
+    "mnli": 3,
+    "mrpc": 2,
+    "sst-2": 2,
+    "sts-b": 1,
+    "qqp": 2,
+    "qnli": 2,
+    "rte": 2,
+    "wnli": 2,
+}
 
-def convert_examples_to_features(examples, label_list, max_seq_length,
-                                 tokenizer, output_mode,
-                                 pad_on_left=False,
-                                 pad_token=0,
-                                 pad_token_segment_id=0,
-                                 mask_padding_with_zero=True):
-    """
-    Loads a data file into a list of `InputBatch`s
-    """
-
-    label_map = {label : i for i, label in enumerate(label_list)}
-
-    features = []
-    for (ex_index, example) in enumerate(examples):
-        if ex_index % 10000 == 0:
-            logger.info("Writing example %d of %d" % (ex_index, len(examples)))
-
-        inputs = tokenizer.encode_plus(
-            example.text_a,
-            example.text_b,
-            add_special_tokens=True,
-            max_length=max_seq_length,
-            truncate_first_sequence=True  # We're truncating the first sequence as a priority
-        )
-        input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"]
-
-        # The mask has 1 for real tokens and 0 for padding tokens. Only real
-        # tokens are attended to.
-        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
-
-        # Zero-pad up to the sequence length.
-        padding_length = max_seq_length - len(input_ids)
-        if pad_on_left:
-            input_ids = ([pad_token] * padding_length) + input_ids
-            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
-            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
-        else:
-            input_ids = input_ids + ([pad_token] * padding_length)
-            input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
-            segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
-
-        assert len(input_ids) == max_seq_length
-        assert len(input_mask) == max_seq_length
-        assert len(segment_ids) == max_seq_length
-
-        if output_mode == "classification":
-            label_id = label_map[example.label]
-        elif output_mode == "regression":
-            label_id = float(example.label)
-        else:
-            raise KeyError(output_mode)
-
-        if ex_index < 5:
-            logger.info("*** Example ***")
-            logger.info("guid: %s" % (example.guid))
-            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
-            logger.info("label: %s (id = %d)" % (example.label, label_id))
-
-        features.append(
-            InputFeatures(input_ids=input_ids,
-                          input_mask=input_mask,
-                          segment_ids=segment_ids,
-                          label_id=label_id))
-    return features
-
-
-def _truncate_seq_pair(tokens_a, tokens_b, max_length):
-    """Truncates a sequence pair in place to the maximum length."""
-
-    # This is a simple heuristic which will always truncate the longer sequence
-    # one token at a time. This makes more sense than truncating an equal percent
-    # of tokens from each, since if one sequence is very short then each token
-    # that's truncated likely contains more information than a longer sequence.
-    while True:
-        total_length = len(tokens_a) + len(tokens_b)
-        if total_length <= max_length:
-            break
-        if len(tokens_a) > len(tokens_b):
-            tokens_a.pop()
-        else:
-            tokens_b.pop()
-
-
-def simple_accuracy(preds, labels):
-    return (preds == labels).mean()
-
-
-def acc_and_f1(preds, labels):
-    acc = simple_accuracy(preds, labels)
-    f1 = f1_score(y_true=labels, y_pred=preds)
-    return {
-        "acc": acc,
-        "f1": f1,
-        "acc_and_f1": (acc + f1) / 2,
-    }
-
-
-def pearson_and_spearman(preds, labels):
-    pearson_corr = pearsonr(preds, labels)[0]
-    spearman_corr = spearmanr(preds, labels)[0]
-    return {
-        "pearson": pearson_corr,
-        "spearmanr": spearman_corr,
-        "corr": (pearson_corr + spearman_corr) / 2,
-    }
-
-
-def compute_metrics(task_name, preds, labels):
-    assert len(preds) == len(labels)
-    if task_name == "cola":
-        return {"mcc": matthews_corrcoef(labels, preds)}
-    elif task_name == "sst-2":
-        return {"acc": simple_accuracy(preds, labels)}
-    elif task_name == "mrpc":
-        return acc_and_f1(preds, labels)
-    elif task_name == "sts-b":
-        return pearson_and_spearman(preds, labels)
-    elif task_name == "qqp":
-        return acc_and_f1(preds, labels)
-    elif task_name == "mnli":
-        return {"acc": simple_accuracy(preds, labels)}
-    elif task_name == "mnli-mm":
-        return {"acc": simple_accuracy(preds, labels)}
-    elif task_name == "qnli":
-        return {"acc": simple_accuracy(preds, labels)}
-    elif task_name == "rte":
-        return {"acc": simple_accuracy(preds, labels)}
-    elif task_name == "wnli":
-        return {"acc": simple_accuracy(preds, labels)}
-    else:
-        raise KeyError(task_name)
-
-processors = {
+glue_processors = {
     "cola": ColaProcessor,
     "mnli": MnliProcessor,
     "mnli-mm": MnliMismatchedProcessor,

@@ -535,7 +414,7 @@ processors = {
     "wnli": WnliProcessor,
 }
 
-output_modes = {
+glue_output_modes = {
     "cola": "classification",
     "mnli": "classification",
     "mnli-mm": "classification",

@@ -547,15 +426,3 @@ output_modes = {
     "rte": "classification",
     "wnli": "classification",
 }
-
-GLUE_TASKS_NUM_LABELS = {
-    "cola": 2,
-    "mnli": 3,
-    "mrpc": 2,
-    "sst-2": 2,
-    "sts-b": 1,
-    "qqp": 2,
-    "qnli": 2,
-    "rte": 2,
-    "wnli": 2,
-}
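
The lowercase glue_tasks_num_labels dict replaces the old GLUE_TASKS_NUM_LABELS constant and is now re-exported from the package root, so classification heads can be sized per task. A small sketch:

from pytorch_transformers import glue_tasks_num_labels

num_labels = glue_tasks_num_labels["sts-b"]   # 1: STS-B is a regression task
num_labels = glue_tasks_num_labels["mnli"]    # 3: entailment / neutral / contradiction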
pytorch_transformers/data/processors/utils.py (new file, 75 lines)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import sys

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r", encoding="utf-8-sig") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                if sys.version_info[0] == 2:
                    line = list(unicode(cell, 'utf-8') for cell in line)
                lines.append(line)
            return lines
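
With DataProcessor importable from the library, downstream code can add non-GLUE tasks the same way. A hedged sketch of a custom processor for a hypothetical two-column TSV (sentence, label) dataset; the class and file names are illustrative only:

import os

from pytorch_transformers import DataProcessor, InputExample

class TwoColumnProcessor(DataProcessor):
    """Hypothetical processor: each TSV row is `sentence<TAB>label`."""

    def get_train_examples(self, data_dir):
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        return [InputExample(guid="%s-%d" % (set_type, i), text_a=line[0], label=line[1])
                for i, line in enumerate(lines)]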