From 043c8781efcfdb28b8d08341c0564304d08efcfd Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Thu, 14 Mar 2019 04:24:04 -0400 Subject: [PATCH 1/5] added code for all glue task processors --- examples/run_classifier.py | 191 +++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index f023a4cf20f..972c110396d 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -167,6 +167,16 @@ class MnliProcessor(DataProcessor): return examples +class MnliMismatchedProcessor(MnliProcessor): + """Processor for the MultiNLI Mismatched data set (GLUE version).""" + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), + "dev_matched") + + class ColaProcessor(DataProcessor): """Processor for the CoLA data set (GLUE version).""" @@ -227,6 +237,170 @@ class Sst2Processor(DataProcessor): return examples +class StsbProcessor(DataProcessor): + """Processor for the STS-B data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return [None] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[7] + text_b = line[8] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class QqpProcessor(DataProcessor): + """Processor for the STS-B data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + try: + text_a = line[3] + text_b = line[4] + label = line[5] + except IndexError: + continue + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class QnliProcessor(DataProcessor): + """Processor for the STS-B data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), + "dev_matched") + + def get_labels(self): + """See base class.""" + return ["entailment", "not_entailment"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class RteProcessor(DataProcessor): + """Processor for the RTE data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return ["entailment", "not_entailment"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class WnliProcessor(DataProcessor): + """Processor for the WNLI data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): """Loads a data file into a list of `InputBatch`s.""" @@ -433,6 +607,23 @@ def main(): "mnli": MnliProcessor, "mrpc": MrpcProcessor, "sst-2": Sst2Processor, + "sts-b": StsbProcessor, + "qqp": QqpProcessor, + "qnli": QnliProcessor, + "rte": RteProcessor, + "wnli": WnliProcessor, + } + + output_modes = { + "cola": "classification", + "mnli": "classification", + "mrpc": "classification", + "sst-2": "classification", + "sts-b": "regression", + "qqp": "classification", + "qnli": "classification", + "rte": "classification", + "wnli": "classification", } num_labels_task = { From 4c721c6b6aba636372b0c9e0ba449e182c6488d5 Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Fri, 15 Mar 2019 23:21:24 -0400 Subject: [PATCH 2/5] added eval time metrics for GLUE tasks --- examples/run_classifier.py | 159 ++++++++++++++++++++++++++++--------- 1 file changed, 123 insertions(+), 36 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 972c110396d..51d57be4dcf 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -31,6 +31,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange +from torch.nn import CrossEntropyLoss, MSELoss +from scipy.stats import pearsonr, spearmanr +from sklearn.metrics import matthews_corrcoef, f1_score + from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME from pytorch_pretrained_bert.tokenization import BertTokenizer @@ -401,13 +405,17 @@ class WnliProcessor(DataProcessor): return examples -def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): +def convert_examples_to_features(examples, label_list, max_seq_length, + tokenizer, output_mode): """Loads a data file into a list of `InputBatch`s.""" label_map = {label : i for i, label in enumerate(label_list)} features = [] for (ex_index, example) in enumerate(examples): + if ex_index % 10000 == 0: + logger.info("Writing example %d of %d" % (ex_index, len(examples))) + tokens_a = tokenizer.tokenize(example.text_a) tokens_b = None @@ -463,7 +471,13 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_length - label_id = label_map[example.label] + if output_mode == "classification": + label_id = label_map[example.label] + elif output_mode == "regression": + label_id = float(example.label) + else: + raise KeyError(output_mode) + if ex_index < 5: logger.info("*** Example ***") logger.info("guid: %s" % (example.guid)) @@ -499,9 +513,56 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): else: tokens_b.pop() -def accuracy(out, labels): - outputs = np.argmax(out, axis=1) - return np.sum(outputs == labels) + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def acc_and_f1(preds, labels): + acc = simple_accuracy(preds, labels) + f1 = f1_score(y_true=labels, y_pred=preds) + return { + "acc": acc, + "f1": f1, + "acc_and_f1": (acc + f1) / 2, + } + + +def pearson_and_spearman(preds, labels): + pearson_corr = pearsonr(preds, labels)[0] + spearman_corr = spearmanr(preds, labels)[0] + return { + "pearson": pearson_corr, + "spearmanr": spearman_corr, + "corr": (pearson_corr + spearman_corr) / 2, + } + + +def compute_metrics(task_name, preds, labels): + assert len(preds) == len(labels) + if task_name == "cola": + return {"mcc": matthews_corrcoef(labels, preds)} + elif task_name == "sst-2": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "mrpc": + return acc_and_f1(preds, labels) + elif task_name == "sts-b": + return pearson_and_spearman(preds, labels) + elif task_name == "qqp": + return acc_and_f1(preds, labels) + elif task_name == "mnli": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "mnli-mm": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "qnli": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "rte": + return {"acc": simple_accuracy(preds, labels)} + elif task_name == "wnli": + return {"acc": simple_accuracy(preds, labels)} + else: + raise KeyError(task_name) + def main(): parser = argparse.ArgumentParser() @@ -605,6 +666,7 @@ def main(): processors = { "cola": ColaProcessor, "mnli": MnliProcessor, + "mnli-mm": MnliMismatchedProcessor, "mrpc": MrpcProcessor, "sst-2": Sst2Processor, "sts-b": StsbProcessor, @@ -617,6 +679,7 @@ def main(): output_modes = { "cola": "classification", "mnli": "classification", + "mnli-mm": "classification", "mrpc": "classification", "sst-2": "classification", "sts-b": "regression", @@ -626,13 +689,6 @@ def main(): "wnli": "classification", } - num_labels_task = { - "cola": 2, - "sst-2": 2, - "mnli": 3, - "mrpc": 2, - } - if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() @@ -671,8 +727,10 @@ def main(): raise ValueError("Task not found: %s" % (task_name)) processor = processors[task_name]() - num_labels = num_labels_task[task_name] + output_mode = output_modes[task_name] + label_list = processor.get_labels() + num_labels = len(label_list) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) @@ -689,7 +747,7 @@ def main(): cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)) model = BertForSequenceClassification.from_pretrained(args.bert_model, cache_dir=cache_dir, - num_labels = num_labels) + num_labels=num_labels) if args.fp16: model.half() model.to(device) @@ -737,7 +795,7 @@ def main(): tr_loss = 0 if args.do_train: train_features = convert_examples_to_features( - train_examples, label_list, args.max_seq_length, tokenizer) + train_examples, label_list, args.max_seq_length, tokenizer, output_mode) logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_examples)) logger.info(" Batch size = %d", args.train_batch_size) @@ -745,7 +803,12 @@ def main(): all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) - all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) + + if output_mode == "classification": + all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) + elif output_mode == "regression": + all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float) + train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) if args.local_rank == -1: train_sampler = RandomSampler(train_data) @@ -760,7 +823,17 @@ def main(): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch - loss = model(input_ids, segment_ids, input_mask, label_ids) + + # define a new function to compute loss values for both output_modes + logits = model(input_ids, segment_ids, input_mask, labels=None) + + if output_mode == "classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) + elif output_mode == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), label_ids.view(-1)) + if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: @@ -805,22 +878,28 @@ def main(): if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = processor.get_dev_examples(args.data_dir) eval_features = convert_examples_to_features( - eval_examples, label_list, args.max_seq_length, tokenizer) + eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) logger.info("***** Running evaluation *****") logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Batch size = %d", args.eval_batch_size) all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) - all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) + + if output_mode == "classification": + all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) + elif output_mode == "regression": + all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float) + eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) # Run prediction for full data eval_sampler = SequentialSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) model.eval() - eval_loss, eval_accuracy = 0, 0 - nb_eval_steps, nb_eval_examples = 0, 0 + eval_loss = 0 + nb_eval_steps = 0 + preds = [] for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): input_ids = input_ids.to(device) @@ -829,26 +908,34 @@ def main(): label_ids = label_ids.to(device) with torch.no_grad(): - tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) - logits = model(input_ids, segment_ids, input_mask) - - logits = logits.detach().cpu().numpy() - label_ids = label_ids.to('cpu').numpy() - tmp_eval_accuracy = accuracy(logits, label_ids) - + logits = model(input_ids, segment_ids, input_mask, labels=None) + + # create eval loss and other metric required by the task + if output_mode == "classification": + loss_fct = CrossEntropyLoss() + tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) + elif output_mode == "regression": + loss_fct = MSELoss() + tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1)) + eval_loss += tmp_eval_loss.mean().item() - eval_accuracy += tmp_eval_accuracy - - nb_eval_examples += input_ids.size(0) nb_eval_steps += 1 + if len(preds) == 0: + preds.append(logits.detach().cpu().numpy()) + else: + preds[0] = np.append( + preds[0], logits.detach().cpu().numpy(), axis=0) eval_loss = eval_loss / nb_eval_steps - eval_accuracy = eval_accuracy / nb_eval_examples + preds = preds[0] + if output_mode == "classification": + preds = np.argmax(preds, axis=1) + result = compute_metrics(task_name, preds, all_label_ids.numpy()) loss = tr_loss/nb_tr_steps if args.do_train else None - result = {'eval_loss': eval_loss, - 'eval_accuracy': eval_accuracy, - 'global_step': global_step, - 'loss': loss} + + result['eval_loss'] = eval_loss + result['global_step'] = global_step + result['loss'] = loss output_eval_file = os.path.join(args.output_dir, "eval_results.txt") with open(output_eval_file, "w") as writer: From e0bf01d9a925e2a2280fea77c48ebd696b66e83e Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Sat, 16 Mar 2019 14:10:48 -0400 Subject: [PATCH 3/5] added hack for mismatched MNLI --- examples/run_classifier.py | 66 +++++++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 51d57be4dcf..588e876af2b 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -679,7 +679,6 @@ def main(): output_modes = { "cola": "classification", "mnli": "classification", - "mnli-mm": "classification", "mrpc": "classification", "sst-2": "classification", "sts-b": "regression", @@ -930,6 +929,8 @@ def main(): preds = preds[0] if output_mode == "classification": preds = np.argmax(preds, axis=1) + elif output_mode == "regression": + preds = np.squeeze(preds) result = compute_metrics(task_name, preds, all_label_ids.numpy()) loss = tr_loss/nb_tr_steps if args.do_train else None @@ -943,6 +944,69 @@ def main(): for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) + + # hack for MNLI-MM + if task_name == "mnli": + task_name = "mnli-mm" + processor = processors[task_name]() + + eval_examples = processor.get_dev_examples(args.data_dir) + eval_features = convert_examples_to_features( + eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) + logger.info("***** Running evaluation *****") + logger.info(" Num examples = %d", len(eval_examples)) + logger.info(" Batch size = %d", args.eval_batch_size) + all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) + all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) + + eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) + # Run prediction for full data + eval_sampler = SequentialSampler(eval_data) + eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) + + model.eval() + eval_loss = 0 + nb_eval_steps = 0 + preds = [] + + for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): + input_ids = input_ids.to(device) + input_mask = input_mask.to(device) + segment_ids = segment_ids.to(device) + label_ids = label_ids.to(device) + + with torch.no_grad(): + logits = model(input_ids, segment_ids, input_mask, labels=None) + + loss_fct = CrossEntropyLoss() + tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) + + eval_loss += tmp_eval_loss.mean().item() + nb_eval_steps += 1 + if len(preds) == 0: + preds.append(logits.detach().cpu().numpy()) + else: + preds[0] = np.append( + preds[0], logits.detach().cpu().numpy(), axis=0) + + eval_loss = eval_loss / nb_eval_steps + preds = preds[0] + preds = np.argmax(preds, axis=1) + result = compute_metrics(task_name, preds, all_label_ids.numpy()) + loss = tr_loss/nb_tr_steps if args.do_train else None + + result['eval_loss'] = eval_loss + result['global_step'] = global_step + result['loss'] = loss + + output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt") + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key in sorted(result.keys()): + logger.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) if __name__ == "__main__": main() From 8a4e90ff40e217d38311737ef4ab43531b6397c4 Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Sun, 17 Mar 2019 08:16:50 -0400 Subject: [PATCH 4/5] corrected folder creation error for MNLI-MM, verified GLUE results --- examples/run_classifier.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 588e876af2b..fb78b6efeb7 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -908,7 +908,7 @@ def main(): with torch.no_grad(): logits = model(input_ids, segment_ids, input_mask, labels=None) - + # create eval loss and other metric required by the task if output_mode == "classification": loss_fct = CrossEntropyLoss() @@ -944,12 +944,17 @@ def main(): for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) - + # hack for MNLI-MM if task_name == "mnli": task_name = "mnli-mm" processor = processors[task_name]() + if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train: + raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) + if not os.path.exists(args.output_dir + '-MM'): + os.makedirs(args.output_dir + '-MM') + eval_examples = processor.get_dev_examples(args.data_dir) eval_features = convert_examples_to_features( eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) @@ -990,7 +995,7 @@ def main(): else: preds[0] = np.append( preds[0], logits.detach().cpu().numpy(), axis=0) - + eval_loss = eval_loss / nb_eval_steps preds = preds[0] preds = np.argmax(preds, axis=1) From f471979167b399ea0e3af7fa39bc960e954f85f5 Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Thu, 21 Mar 2019 15:38:30 -0400 Subject: [PATCH 5/5] added GLUE dev set results and details on how to run GLUE tasks --- README.md | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index afc2b6efdad..98fd27badb3 100644 --- a/README.md +++ b/README.md @@ -927,11 +927,60 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/): -- a *sequence-level classifier* on the MRPC classification corpus, +- a *sequence-level classifier* on nine different GLUE tasks, - a *token-level classifier* on the question answering dataset SQuAD, and - a *sequence-level multiple-choice classifier* on the SWAG classification corpus. - a *BERT language model* on another target corpus +#### GLUE results on dev set + +We get the following results on the dev set of GLUE benchmark with an uncased BERT base +model. All experiments were run on a P100 GPU with a batch size of 32. + +| Task | Metric | Result | +|-|-|-| +| CoLA | Matthew's corr. | 57.29 | +| SST-2 | accuracy | 93.00 | +| MRPC | F1/accuracy | 88.85/83.82 | +| STS-B | Pearson/Spearman corr. | 89.70/89.37 | +| QQP | accuracy/F1 | 90.72/87.41 | +| MNLI | matched acc./mismatched acc.| 83.95/84.39 | +| QNLI | accuracy | 89.04 | +| RTE | accuracy | 61.01 | +| WNLI | accuracy | 53.52 | + +Some of these results are significantly different from the ones reported on the test set +of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the webite. + +Before running anyone of these GLUE tasks you should download the +[GLUE data](https://gluebenchmark.com/tasks) by running +[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) +and unpack it to some directory `$GLUE_DIR`. + +```shell +export GLUE_DIR=/path/to/glue +export TASK_NAME=MRPC + +python run_classifier.py \ + --task_name $TASK_NAME \ + --do_train \ + --do_eval \ + --do_lower_case \ + --data_dir $GLUE_DIR/$TASK_NAME \ + --bert_model bert-base-uncased \ + --max_seq_length 128 \ + --train_batch_size 32 \ + --learning_rate 2e-5 \ + --num_train_epochs 3.0 \ + --output_dir /tmp/$TASK_NAME/ +``` + +where task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI. + +The dev set results will be present within the text file 'eval_results.txt' in the specified output_dir. In case of MNLI, since there are two separate dev sets, matched and mismatched, there will be a separate output folder called '/tmp/MNLI-MM/' in addition to '/tmp/MNLI/'. + +The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI, CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being said, there shouldn't be any issues in running half-precision training with the remaining GLUE tasks as well, since the data processor for each task inherits from the base class DataProcessor. + #### MRPC This example code fine-tunes BERT on the Microsoft Research Paraphrase