examples: add XLM-RoBERTa to glue script

This commit is contained in:
Stefan Schweter 2019-12-19 02:23:01 +01:00
parent fe9aab1055
commit a26ce4dee1

View File

@ -52,6 +52,9 @@ from transformers import (WEIGHTS_NAME, BertConfig,
AlbertConfig, AlbertConfig,
AlbertForSequenceClassification, AlbertForSequenceClassification,
AlbertTokenizer, AlbertTokenizer,
XLMRobertaConfig,
XLMRobertaForSequenceClassification,
XLMRobertaTokenizer,
) )
from transformers import AdamW, get_linear_schedule_with_warmup from transformers import AdamW, get_linear_schedule_with_warmup
@ -72,7 +75,8 @@ MODEL_CLASSES = {
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer), 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer) 'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
} }
@ -304,9 +308,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
else: else:
logger.info("Creating features from dataset file at %s", args.data_dir) logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels() label_list = processor.get_labels()
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']: if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
# HACK(label indices are swapped in RoBERTa pretrained model) # HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1] label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, features = convert_examples_to_features(examples,
tokenizer, tokenizer,