mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Run classifier processor for SST-2.
This commit is contained in:
parent
2152bfeae8
commit
0f96d4b1f7
@ -196,6 +196,37 @@ class ColaProcessor(DataProcessor):
|
||||
return examples
|
||||
|
||||
|
||||
class Sst2Processor(DataProcessor):
|
||||
"""Processor for the SST-2 data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = line[0]
|
||||
label = line[1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
@ -401,10 +432,12 @@ def main():
|
||||
"cola": ColaProcessor,
|
||||
"mnli": MnliProcessor,
|
||||
"mrpc": MrpcProcessor,
|
||||
"sst-2": Sst2Processor,
|
||||
}
|
||||
|
||||
num_labels_task = {
|
||||
"cola": 2,
|
||||
"sst-2": 2,
|
||||
"mnli": 3,
|
||||
"mrpc": 2,
|
||||
}
|
||||
@ -597,7 +630,7 @@ def main():
|
||||
model.eval()
|
||||
eval_loss, eval_accuracy = 0, 0
|
||||
nb_eval_steps, nb_eval_examples = 0, 0
|
||||
|
||||
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
|
Loading…
Reference in New Issue
Block a user