From a26ce4dee116a1d5d9099c8a94e22d1e31ad631c Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Thu, 19 Dec 2019 02:23:01 +0100 Subject: [PATCH] examples: add XLM-RoBERTa to glue script --- examples/run_glue.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index 1a51255c110..954a8fbf0c5 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -52,6 +52,9 @@ from transformers import (WEIGHTS_NAME, BertConfig, AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer, + XLMRobertaConfig, + XLMRobertaForSequenceClassification, + XLMRobertaTokenizer, ) from transformers import AdamW, get_linear_schedule_with_warmup @@ -72,7 +75,8 @@ MODEL_CLASSES = { 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer), - 'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer) + 'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer), + 'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer), } @@ -304,9 +308,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): else: logger.info("Creating features from dataset file at %s", args.data_dir) label_list = processor.get_labels() - if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']: + if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']: # HACK(label indices are swapped in RoBERTa pretrained model) - label_list[1], label_list[2] = label_list[2], label_list[1] + label_list[1], label_list[2] = label_list[2], label_list[1] examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) features = convert_examples_to_features(examples, tokenizer,