Run classifier processor for SST-2.

This commit is contained in:
John Lehmann 2019-03-05 13:38:28 -06:00
parent 2152bfeae8
commit 0f96d4b1f7

View File

@ -196,6 +196,37 @@ class ColaProcessor(DataProcessor):
return examples
class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
label = line[1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
@ -401,10 +432,12 @@ def main():
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
}
num_labels_task = {
"cola": 2,
"sst-2": 2,
"mnli": 3,
"mrpc": 2,
}