diff --git a/examples/run_glue.py b/examples/run_glue.py
index 1a51255c110..954a8fbf0c5 100644
--- a/examples/run_glue.py
+++ b/examples/run_glue.py
@@ -52,6 +52,9 @@ from transformers import (WEIGHTS_NAME, BertConfig,
                           AlbertConfig,
                           AlbertForSequenceClassification,
                           AlbertTokenizer,
+                          XLMRobertaConfig,
+                          XLMRobertaForSequenceClassification,
+                          XLMRobertaTokenizer,
                         )

 from transformers import AdamW, get_linear_schedule_with_warmup
@@ -72,7 +75,8 @@ MODEL_CLASSES = {
     'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
     'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
     'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
-    'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer)
+    'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
+    'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
 }
@@ -304,9 +308,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     else:
         logger.info("Creating features from dataset file at %s", args.data_dir)
         label_list = processor.get_labels()
-        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
+        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
             # HACK(label indices are swapped in RoBERTa pretrained model)
-            label_list[1], label_list[2] = label_list[2], label_list[1]
+            label_list[1], label_list[2] = label_list[2], label_list[1]
         examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
         features = convert_examples_to_features(examples, tokenizer,
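
For context, a minimal sketch of how run_glue.py consumes the new MODEL_CLASSES entry: the config/model/tokenizer triple is looked up by --model_type and each class is loaded via from_pretrained. The checkpoint name 'xlm-roberta-base', the label count, and the MNLI task name below are illustrative assumptions and are not part of this diff.

    # Sketch (not part of the patch): resolving the new 'xlmroberta' entry.
    from transformers import (XLMRobertaConfig,
                              XLMRobertaForSequenceClassification,
                              XLMRobertaTokenizer)

    MODEL_CLASSES = {
        'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
    }

    # run_glue.py does the equivalent of this with args.model_type / args.model_name_or_path
    config_class, model_class, tokenizer_class = MODEL_CLASSES['xlmroberta']
    config = config_class.from_pretrained('xlm-roberta-base', num_labels=3, finetuning_task='mnli')
    tokenizer = tokenizer_class.from_pretrained('xlm-roberta-base')
    model = model_class.from_pretrained('xlm-roberta-base', config=config)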