mirror of https://github.com/huggingface/transformers.git
examples: add XLM-RoBERTa to glue script
parent fe9aab1055
commit a26ce4dee1
@@ -52,6 +52,9 @@ from transformers import (WEIGHTS_NAME, BertConfig,
                           AlbertConfig,
                           AlbertForSequenceClassification,
                           AlbertTokenizer,
+                          XLMRobertaConfig,
+                          XLMRobertaForSequenceClassification,
+                          XLMRobertaTokenizer,
                           )
 
 from transformers import AdamW, get_linear_schedule_with_warmup
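For readers new to these classes: the three new imports form the usual config/model/tokenizer triple that the script wires together per model family. A minimal standalone sketch, assuming the public xlm-roberta-base checkpoint and a 2-label head (neither is part of this diff):

import torch
from transformers import (XLMRobertaConfig,
                          XLMRobertaForSequenceClassification,
                          XLMRobertaTokenizer)

# 'xlm-roberta-base' is an assumed public checkpoint, used here only for illustration.
config = XLMRobertaConfig.from_pretrained('xlm-roberta-base', num_labels=2)
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base', config=config)

input_ids = torch.tensor([tokenizer.encode("Une phrase multilingue.")])
logits = model(input_ids)[0]  # classification logits, shape (1, 2)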
@@ -72,7 +75,8 @@ MODEL_CLASSES = {
     'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
     'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
     'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
-    'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer)
+    'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
+    'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
 }
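For context, run_glue.py resolves --model_type through this table and instantiates each class from --model_name_or_path. A simplified sketch of that lookup, with hard-coded values standing in for the command-line arguments:

from transformers import (XLMRobertaConfig,
                          XLMRobertaForSequenceClassification,
                          XLMRobertaTokenizer)

# Trimmed copy of the table above, limited to the entry this commit adds.
MODEL_CLASSES = {
    'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
}

model_type = 'xlmroberta'                # stands in for args.model_type
model_name_or_path = 'xlm-roberta-base'  # stands in for args.model_name_or_path

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
config = config_class.from_pretrained(model_name_or_path, num_labels=3, finetuning_task='mnli')
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
model = model_class.from_pretrained(model_name_or_path, config=config)

With the new 'xlmroberta' key, passing --model_type xlmroberta selects this triple the same way 'roberta' or 'albert' do.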
@@ -304,9 +308,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     else:
         logger.info("Creating features from dataset file at %s", args.data_dir)
         label_list = processor.get_labels()
-        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
+        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
             # HACK(label indices are swapped in RoBERTa pretrained model)
             label_list[1], label_list[2] = label_list[2], label_list[1]
         examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
         features = convert_examples_to_features(examples,
                                                 tokenizer,
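The HACK comment concerns MNLI label ordering: the library's MNLI processor returns ['contradiction', 'entailment', 'neutral'], and, per that comment, the pretrained RoBERTa models use the last two positions swapped; this change simply extends the same treatment to XLM-RoBERTa. A self-contained illustration of the swap (the hard-coded task and model_type stand in for the script's arguments):

# Label list as returned by the GLUE MNLI processor (processor.get_labels()).
label_list = ['contradiction', 'entailment', 'neutral']

task = 'mnli'              # stands in for the task name passed to the script
model_type = 'xlmroberta'  # stands in for args.model_type

# Same guard and swap as in the hunk above, now also covering 'xlmroberta'.
if task in ['mnli', 'mnli-mm'] and model_type in ['roberta', 'xlmroberta']:
    label_list[1], label_list[2] = label_list[2], label_list[1]

print(label_list)  # ['contradiction', 'neutral', 'entailment']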