mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add multiple choice to robreta and xlnet, test on swag, roberta=0.82.28
, xlnet=0.80
This commit is contained in:
parent
e384ae2b9d
commit
5582bc4b23
495
examples/single_model_scripts/run_multiple_choice.py
Normal file
495
examples/single_model_scripts/run_multiple_choice.py
Normal file
@ -0,0 +1,495 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for multiple choice (Bert, XLM, XLNet)."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForMultipleChoice, BertTokenizer,
|
||||
XLNetConfig, XLNetForMultipleChoice,
|
||||
XLNetTokenizer, RobertaConfig,
|
||||
RobertaForMultipleChoice, RobertaTokenizer)
|
||||
|
||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_multiple_choice import (convert_examples_to_features, processors)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer)
|
||||
}
|
||||
|
||||
def select_field(features, field):
|
||||
return [
|
||||
[
|
||||
choice[field]
|
||||
for choice in feature.choices_features
|
||||
]
|
||||
for feature in features
|
||||
]
|
||||
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
scheduler.step() # Update learning rate schedule
|
||||
optimizer.step()
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
logger.info("Average loss: %s at global step: %s", str((tr_loss - logging_loss)/args.logging_steps), str(global_step))
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_vocabulary(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
preds = np.argmax(preds, axis=1)
|
||||
acc = simple_accuracy(preds, out_label_ids)
|
||||
result = {"eval_acc": acc, "eval_loss": eval_loss}
|
||||
results.update(result)
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task]()
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
logger.info("Training number: %s", str(len(examples)))
|
||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
|
||||
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args.model_type in ['roberta']),
|
||||
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
|
||||
all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
|
||||
all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare GLUE task
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name]()
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
if not args.do_train:
|
||||
args.output_dir = args.model_name_or_path
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
421
examples/single_model_scripts/utils_multiple_choice.py
Normal file
421
examples/single_model_scripts/utils_multiple_choice.py
Normal file
@ -0,0 +1,421 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BERT classification fine-tuning: utilities to work with GLUE tasks """
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
import json
|
||||
import csv
|
||||
import glob
|
||||
import tqdm
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
"""A single training/test example for multiple choice"""
|
||||
|
||||
def __init__(self, example_id, question, contexts, endings, label=None):
|
||||
"""Constructs a InputExample.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
text_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
self.example_id = example_id
|
||||
self.question = question
|
||||
self.contexts = contexts
|
||||
self.endings = endings
|
||||
self.label = label
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
def __init__(self,
|
||||
example_id,
|
||||
choices_features,
|
||||
label
|
||||
|
||||
):
|
||||
self.example_id = example_id
|
||||
self.choices_features = [
|
||||
{
|
||||
'input_ids': input_ids,
|
||||
'input_mask': input_mask,
|
||||
'segment_ids': segment_ids
|
||||
}
|
||||
for _, input_ids, input_mask, segment_ids in choices_features
|
||||
]
|
||||
self.label = label
|
||||
|
||||
|
||||
class DataProcessor(object):
|
||||
"""Base class for data converters for sequence classification data sets."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the train set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the dev set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_labels(self):
|
||||
"""Gets the list of labels for this data set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RaceProcessor(DataProcessor):
|
||||
"""Processor for the MRPC data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
high = os.path.join(data_dir, 'train/high')
|
||||
middle = os.path.join(data_dir, 'train/middle')
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, 'train')
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
high = os.path.join(data_dir, 'dev/high')
|
||||
middle = os.path.join(data_dir, 'dev/middle')
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, 'dev')
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_txt(self, input_dir):
|
||||
lines = []
|
||||
files = glob.glob(input_dir + "/*txt")
|
||||
for file in tqdm.tqdm(files, desc="read files"):
|
||||
with open(file, 'r', encoding='utf-8') as fin:
|
||||
data_raw = json.load(fin)
|
||||
data_raw["race_id"] = file
|
||||
lines.append(data_raw)
|
||||
return lines
|
||||
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (_, data_raw) in enumerate(lines):
|
||||
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
||||
article = data_raw["article"]
|
||||
for i in range(len(data_raw["answers"])):
|
||||
truth = str(ord(data_raw['answers'][i]) - ord('A'))
|
||||
question = data_raw['questions'][i]
|
||||
options = data_raw['options'][i]
|
||||
|
||||
examples.append(
|
||||
InputExample(
|
||||
example_id=race_id,
|
||||
question=question,
|
||||
contexts=[article, article, article, article],
|
||||
endings=[options[0], options[1], options[2], options[3]],
|
||||
label=truth))
|
||||
return examples
|
||||
|
||||
class SwagProcessor(DataProcessor):
|
||||
"""Processor for the MRPC data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
reader = csv.reader(f)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
def _create_examples(self, lines, type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
if type == "train" and lines[0][-1] != 'label':
|
||||
raise ValueError(
|
||||
"For training, the input file must contain a label column."
|
||||
)
|
||||
|
||||
examples = [
|
||||
InputExample(
|
||||
example_id=line[2],
|
||||
question=line[5], # in the swag dataset, the
|
||||
# common beginning of each
|
||||
# choice is stored in "sent2".
|
||||
contexts = [line[4], line[4], line[4], line[4]],
|
||||
endings = [line[7], line[8], line[9], line[10]],
|
||||
label=line[11]
|
||||
) for line in lines[1:] # we skip the line with the column names
|
||||
]
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
class ArcProcessor(DataProcessor):
|
||||
"""Processor for the MRPC data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "train.jsonl")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_json(self, input_file):
|
||||
with open(input_file, 'r', encoding='utf-8') as fin:
|
||||
lines = fin.readlines()
|
||||
return lines
|
||||
|
||||
|
||||
def _create_examples(self, lines, type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
|
||||
def normalize(truth):
|
||||
if truth in "ABCD":
|
||||
return ord(truth) - ord("A")
|
||||
elif truth in "1234":
|
||||
return int(truth) - 1
|
||||
else:
|
||||
logger.info("truth ERROR!")
|
||||
examples = []
|
||||
three_choice = 0
|
||||
four_choice = 0
|
||||
five_choice = 0
|
||||
other_choices = 0
|
||||
for line in tqdm.tqdm(lines, desc="read arc data"):
|
||||
data_raw = json.loads(line.strip("\n"))
|
||||
if len(data_raw["question"]["choices"]) == 3:
|
||||
three_choice += 1
|
||||
continue
|
||||
elif len(data_raw["question"]["choices"]) == 5:
|
||||
five_choice += 1
|
||||
continue
|
||||
elif len(data_raw["question"]["choices"]) != 4:
|
||||
other_choices += 1
|
||||
continue
|
||||
four_choice += 1
|
||||
truth = str(normalize(data_raw["answerKey"]))
|
||||
question_choices = data_raw["question"]
|
||||
question = question_choices["stem"]
|
||||
id = data_raw["id"]
|
||||
options = question_choices["choices"]
|
||||
|
||||
if len(options) == 4:
|
||||
examples.append(
|
||||
InputExample(
|
||||
example_id = id,
|
||||
question=question,
|
||||
contexts=[options[0]["para"].replace("_", ""), options[1]["para"].replace("_", ""),
|
||||
options[2]["para"].replace("_", ""), options[3]["para"].replace("_", "")],
|
||||
endings=[options[0]["text"], options[1]["text"], options[2]["text"], options[3]["text"]],
|
||||
label=truth))
|
||||
|
||||
if type == "train":
|
||||
assert len(examples) > 1
|
||||
assert examples[0].label is not None
|
||||
logger.info("len examples: %s}", str(len(examples)))
|
||||
logger.info("Three choices: %s", str(three_choice))
|
||||
logger.info("Five choices: %s", str(five_choice))
|
||||
logger.info("Other choices: %s", str(other_choices))
|
||||
logger.info("four choices: %s", str(four_choice))
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=False,
|
||||
cls_token='[CLS]',
|
||||
cls_token_segment_id=1,
|
||||
sep_token='[SEP]',
|
||||
sequence_a_segment_id=0,
|
||||
sequence_b_segment_id=1,
|
||||
sep_token_extra=False,
|
||||
pad_token_segment_id=0,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
mask_padding_with_zero=True):
|
||||
""" Loads a data file into a list of `InputBatch`s
|
||||
`cls_token_at_end` define the location of the CLS token:
|
||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
|
||||
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
|
||||
"""
|
||||
|
||||
label_map = {label : i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
choices_features = []
|
||||
for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
|
||||
tokens_a = tokenizer.tokenize(context)
|
||||
tokens_b = None
|
||||
if example.question.find("_") != -1:
|
||||
tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
|
||||
else:
|
||||
tokens_b = tokenizer.tokenize(example.question + " " + ending)
|
||||
special_tokens_count = 4 if sep_token_extra else 3
|
||||
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
|
||||
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
||||
# (b) For single sequences:
|
||||
# tokens: [CLS] the dog is hairy . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0
|
||||
#
|
||||
# Where "type_ids" are used to indicate whether this is the first
|
||||
# sequence or the second sequence. The embedding vectors for `type=0` and
|
||||
# `type=1` were learned during pre-training and are added to the wordpiece
|
||||
# embedding vector (and position vector). This is not *strictly* necessary
|
||||
# since the [SEP] token unambiguously separates the sequences, but it makes
|
||||
# it easier for the model to learn the concept of sequences.
|
||||
#
|
||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||
# used as as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = tokens_a + [sep_token]
|
||||
if sep_token_extra:
|
||||
# roberta uses an extra separator b/w pairs of sentences
|
||||
tokens += [sep_token]
|
||||
|
||||
segment_ids = [sequence_a_segment_id] * len(tokens)
|
||||
|
||||
if tokens_b:
|
||||
tokens += tokens_b + [sep_token]
|
||||
segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
|
||||
|
||||
if cls_token_at_end:
|
||||
tokens = tokens + [cls_token]
|
||||
segment_ids = segment_ids + [cls_token_segment_id]
|
||||
else:
|
||||
tokens = [cls_token] + tokens
|
||||
segment_ids = [cls_token_segment_id] + segment_ids
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
padding_length = max_seq_length - len(input_ids)
|
||||
if pad_on_left:
|
||||
input_ids = ([pad_token] * padding_length) + input_ids
|
||||
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
|
||||
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
||||
else:
|
||||
input_ids = input_ids + ([pad_token] * padding_length)
|
||||
input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
||||
segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
choices_features.append((tokens, input_ids, input_mask, segment_ids))
|
||||
label = label_map[example.label]
|
||||
|
||||
if ex_index < 2:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("race_id: {}".format(example.example_id))
|
||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||
logger.info("choice: {}".format(choice_idx))
|
||||
logger.info("tokens: {}".format(' '.join(tokens)))
|
||||
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
||||
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
|
||||
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
|
||||
logger.info("label: {}".format(label))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
example_id = example.example_id,
|
||||
choices_features = choices_features,
|
||||
label = label
|
||||
)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
processors = {
|
||||
"race": RaceProcessor,
|
||||
"swag": SwagProcessor,
|
||||
"arc": ArcProcessor
|
||||
}
|
||||
|
||||
|
||||
GLUE_TASKS_NUM_LABELS = {
|
||||
"race", 4,
|
||||
"swag", 4,
|
||||
"arc", 4
|
||||
}
|
@ -31,7 +31,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2PreTrainedModel, GPT2Model,
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlnet import (XLNetConfig,
|
||||
XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering,
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering, XLNetForMultipleChoice,
|
||||
load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
|
||||
@ -39,6 +39,7 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
|
||||
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
||||
RobertaForMultipleChoice,
|
||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
|
||||
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
|
||||
|
@ -329,6 +329,46 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
class RobertaForMultipleChoice(BertPreTrainedModel):
|
||||
config_class = RobertaConfig
|
||||
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super(RobertaForMultipleChoice, self).__init__(config)
|
||||
|
||||
self.roberta = RobertaModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
||||
position_ids=None, head_mask=None):
|
||||
num_choices = input_ids.shape[1]
|
||||
|
||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
outputs = self.roberta(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
|
||||
attention_mask=flat_attention_mask, head_mask=head_mask)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
|
||||
class RobertaClassificationHead(nn.Module):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
@ -1143,6 +1143,50 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
return outputs # return (loss), logits, mems, (hidden states), (attentions)
|
||||
|
||||
|
||||
class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
||||
r"""
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLNetForMultipleChoice, self).__init__(config)
|
||||
|
||||
self.transformer = XLNetModel(config)
|
||||
self.sequence_summary = SequenceSummary(config)
|
||||
self.logits_proj = nn.Linear(config.d_model, 1)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None,
|
||||
labels=None, head_mask=None):
|
||||
num_choices = input_ids.shape[1]
|
||||
|
||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
flat_input_mask = input_mask.view(-1, input_mask.size(-1) if input_mask is not None else None)
|
||||
|
||||
transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
|
||||
input_mask=flat_input_mask, attention_mask=flat_attention_mask,
|
||||
mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
|
||||
head_mask=head_mask)
|
||||
|
||||
|
||||
output = transformer_outputs[0]
|
||||
|
||||
output = self.sequence_summary(output)
|
||||
logits = self.logits_proj(output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
outputs = (reshaped_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # return (loss), logits, mems, (hidden states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
|
Loading…
Reference in New Issue
Block a user