mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
added code for all glue task processors
This commit is contained in:
parent
9b03d67b83
commit
043c8781ef
@ -167,6 +167,16 @@ class MnliProcessor(DataProcessor):
|
|||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class MnliMismatchedProcessor(MnliProcessor):
|
||||||
|
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
|
||||||
|
"dev_matched")
|
||||||
|
|
||||||
|
|
||||||
class ColaProcessor(DataProcessor):
|
class ColaProcessor(DataProcessor):
|
||||||
"""Processor for the CoLA data set (GLUE version)."""
|
"""Processor for the CoLA data set (GLUE version)."""
|
||||||
|
|
||||||
@ -227,6 +237,170 @@ class Sst2Processor(DataProcessor):
|
|||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class StsbProcessor(DataProcessor):
|
||||||
|
"""Processor for the STS-B data set (GLUE version)."""
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""See base class."""
|
||||||
|
return [None]
|
||||||
|
|
||||||
|
def _create_examples(self, lines, set_type):
|
||||||
|
"""Creates examples for the training and dev sets."""
|
||||||
|
examples = []
|
||||||
|
for (i, line) in enumerate(lines):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
|
text_a = line[7]
|
||||||
|
text_b = line[8]
|
||||||
|
label = line[-1]
|
||||||
|
examples.append(
|
||||||
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class QqpProcessor(DataProcessor):
|
||||||
|
"""Processor for the STS-B data set (GLUE version)."""
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""See base class."""
|
||||||
|
return ["0", "1"]
|
||||||
|
|
||||||
|
def _create_examples(self, lines, set_type):
|
||||||
|
"""Creates examples for the training and dev sets."""
|
||||||
|
examples = []
|
||||||
|
for (i, line) in enumerate(lines):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
|
try:
|
||||||
|
text_a = line[3]
|
||||||
|
text_b = line[4]
|
||||||
|
label = line[5]
|
||||||
|
except IndexError:
|
||||||
|
continue
|
||||||
|
examples.append(
|
||||||
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class QnliProcessor(DataProcessor):
|
||||||
|
"""Processor for the STS-B data set (GLUE version)."""
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
|
||||||
|
"dev_matched")
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""See base class."""
|
||||||
|
return ["entailment", "not_entailment"]
|
||||||
|
|
||||||
|
def _create_examples(self, lines, set_type):
|
||||||
|
"""Creates examples for the training and dev sets."""
|
||||||
|
examples = []
|
||||||
|
for (i, line) in enumerate(lines):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
|
text_a = line[1]
|
||||||
|
text_b = line[2]
|
||||||
|
label = line[-1]
|
||||||
|
examples.append(
|
||||||
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class RteProcessor(DataProcessor):
|
||||||
|
"""Processor for the RTE data set (GLUE version)."""
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""See base class."""
|
||||||
|
return ["entailment", "not_entailment"]
|
||||||
|
|
||||||
|
def _create_examples(self, lines, set_type):
|
||||||
|
"""Creates examples for the training and dev sets."""
|
||||||
|
examples = []
|
||||||
|
for (i, line) in enumerate(lines):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
|
text_a = line[1]
|
||||||
|
text_b = line[2]
|
||||||
|
label = line[-1]
|
||||||
|
examples.append(
|
||||||
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class WnliProcessor(DataProcessor):
|
||||||
|
"""Processor for the WNLI data set (GLUE version)."""
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""See base class."""
|
||||||
|
return ["0", "1"]
|
||||||
|
|
||||||
|
def _create_examples(self, lines, set_type):
|
||||||
|
"""Creates examples for the training and dev sets."""
|
||||||
|
examples = []
|
||||||
|
for (i, line) in enumerate(lines):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
|
text_a = line[1]
|
||||||
|
text_b = line[2]
|
||||||
|
label = line[-1]
|
||||||
|
examples.append(
|
||||||
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
@ -433,6 +607,23 @@ def main():
|
|||||||
"mnli": MnliProcessor,
|
"mnli": MnliProcessor,
|
||||||
"mrpc": MrpcProcessor,
|
"mrpc": MrpcProcessor,
|
||||||
"sst-2": Sst2Processor,
|
"sst-2": Sst2Processor,
|
||||||
|
"sts-b": StsbProcessor,
|
||||||
|
"qqp": QqpProcessor,
|
||||||
|
"qnli": QnliProcessor,
|
||||||
|
"rte": RteProcessor,
|
||||||
|
"wnli": WnliProcessor,
|
||||||
|
}
|
||||||
|
|
||||||
|
output_modes = {
|
||||||
|
"cola": "classification",
|
||||||
|
"mnli": "classification",
|
||||||
|
"mrpc": "classification",
|
||||||
|
"sst-2": "classification",
|
||||||
|
"sts-b": "regression",
|
||||||
|
"qqp": "classification",
|
||||||
|
"qnli": "classification",
|
||||||
|
"rte": "classification",
|
||||||
|
"wnli": "classification",
|
||||||
}
|
}
|
||||||
|
|
||||||
num_labels_task = {
|
num_labels_task = {
|
||||||
|
Loading…
Reference in New Issue
Block a user