From 043c8781efcfdb28b8d08341c0564304d08efcfd Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Thu, 14 Mar 2019 04:24:04 -0400 Subject: [PATCH] added code for all glue task processors --- examples/run_classifier.py | 191 +++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index f023a4cf20f..972c110396d 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -167,6 +167,16 @@ class MnliProcessor(DataProcessor): return examples +class MnliMismatchedProcessor(MnliProcessor): + """Processor for the MultiNLI Mismatched data set (GLUE version).""" + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), + "dev_matched") + + class ColaProcessor(DataProcessor): """Processor for the CoLA data set (GLUE version).""" @@ -227,6 +237,170 @@ class Sst2Processor(DataProcessor): return examples +class StsbProcessor(DataProcessor): + """Processor for the STS-B data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return [None] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[7] + text_b = line[8] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class QqpProcessor(DataProcessor): + """Processor for the STS-B data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + try: + text_a = line[3] + text_b = line[4] + label = line[5] + except IndexError: + continue + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class QnliProcessor(DataProcessor): + """Processor for the STS-B data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), + "dev_matched") + + def get_labels(self): + """See base class.""" + return ["entailment", "not_entailment"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class RteProcessor(DataProcessor): + """Processor for the RTE data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return ["entailment", "not_entailment"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class WnliProcessor(DataProcessor): + """Processor for the WNLI data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): """Loads a data file into a list of `InputBatch`s.""" @@ -433,6 +607,23 @@ def main(): "mnli": MnliProcessor, "mrpc": MrpcProcessor, "sst-2": Sst2Processor, + "sts-b": StsbProcessor, + "qqp": QqpProcessor, + "qnli": QnliProcessor, + "rte": RteProcessor, + "wnli": WnliProcessor, + } + + output_modes = { + "cola": "classification", + "mnli": "classification", + "mrpc": "classification", + "sst-2": "classification", + "sts-b": "regression", + "qqp": "classification", + "qnli": "classification", + "rte": "classification", + "wnli": "classification", } num_labels_task = {