Merge branch 'refs/heads/squad_roberta'

# Conflicts:
#	transformers/data/processors/squad.py
This commit is contained in:
erenup 2019-12-14 08:53:59 +08:00
commit c7780700f5
6 changed files with 424 additions and 185 deletions

View File

@ -39,6 +39,7 @@ from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer,
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
XLMConfig, XLMForQuestionAnswering,
XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering,
@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
'start_positions': batch[3],
'end_positions': batch[4]
'end_positions': batch[4],
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
with torch.no_grad():
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1]
'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
example_indices = batch[3]
# XLNet and XLM use more arguments for their predictions
@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)
args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
# Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions)
@ -363,7 +360,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
return_dataset='pt'
return_dataset='pt',
threads=args.threads,
)
if args.local_rank in [-1, 0]:
@ -481,6 +479,8 @@ def main():
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--threads', type=int, default=1, help='multiple threads for converting example to features')
args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:

130
srl_label.txt Normal file
View File

@ -0,0 +1,130 @@
O
I-ARG1
I-ARG2
I-ARG0
B-V
B-ARG1
B-ARG0
I-ARGM-ADV
I-ARGM-TMP
B-ARG2
I-ARGM-LOC
I-ARGM-MNR
B-ARGM-TMP
I-ARGM-CAU
I-ARGM-PRP
B-ARGM-MOD
I-C-ARG1
B-ARGM-ADV
I-ARGM-PRD
B-ARGM-DIS
I-ARG3
I-V
I-ARG4
B-ARGM-MNR
B-ARGM-LOC
I-ARGM-NEG
B-ARGM-NEG
B-R-ARG0
I-ARGM-DIR
I-ARGM-DIS
I-ARGM-PNC
I-ARGM-ADJ
B-R-ARG1
B-ARG3
B-ARGM-PRP
B-ARG4
I-ARGM-GOL
I-R-ARG0
B-ARGM-CAU
B-ARGM-DIR
B-ARGM-PRD
I-ARGM-EXT
B-C-ARG1
B-ARGM-ADJ
I-C-ARG0
B-ARGM-EXT
I-C-ARG2
I-ARGM-COM
I-R-ARG1
I-ARGM-MOD
B-ARGM-GOL
B-ARGM-PNC
B-R-ARGM-LOC
B-R-ARGM-TMP
B-ARGM-LVB
B-ARGM-COM
B-R-ARG2
I-C-ARGM-MNR
B-C-ARG0
I-R-ARGM-LOC
B-C-ARG2
I-C-ARGM-EXT
I-C-ARG4
B-ARGM-REC
I-R-ARG2
I-C-ARGM-TMP
I-ARG5
I-C-ARG3
I-C-ARGM-ADV
B-ARG5
B-R-ARGM-MNR
I-ARGM-DSP
I-C-ARGM-LOC
B-R-ARG3
I-ARGA
I-R-ARGM-MNR
B-R-ARGM-CAU
I-R-ARGM-TMP
B-C-ARGM-MNR
B-ARGA
I-C-ARGM-DSP
B-C-ARGM-ADV
I-R-ARG3
B-R-ARGM-ADV
B-C-ARG4
I-C-ARGM-CAU
B-C-ARGM-EXT
B-C-ARGM-TMP
B-R-ARGM-DIR
B-R-ARG4
I-R-ARGM-ADV
I-ARGM-REC
B-C-ARG3
B-C-ARGM-LOC
B-R-ARGM-EXT
B-ARGM-PRR
B-R-ARGM-PRP
B-ARGM-PRX
I-R-ARGM-DIR
I-R-ARGM-EXT
I-C-ARGM-NEG
B-ARGM-DSP
B-R-ARGM-GOL
I-R-ARGM-GOL
I-R-ARGM-PNC
I-C-ARGM-PRP
B-R-ARGM-COM
I-R-ARGM-PRP
I-C-ARGM-COM
B-C-ARGM-CAU
B-C-ARGM-DSP
I-R-ARGM-COM
I-R-ARGM-CAU
B-R-ARGM-PNC
I-C-ARGM-DIS
I-C-ARGM-DIR
I-R-ARG4
B-R-ARGM-PRD
I-R-ARGM-PRD
B-C-ARGM-PRP
B-R-ARG5
B-C-ARGM-MOD
I-C-ARGM-MOD
B-C-ARGM-ADJ
I-C-ARGM-ADJ
B-C-ARGM-DIS
B-C-ARGM-NEG
B-C-ARGM-COM
B-C-ARGM-DIR
B-R-ARGM-MOD

View File

@ -99,7 +99,7 @@ if is_torch_available():
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
RobertaForSequenceClassification, RobertaForMultipleChoice,
RobertaForTokenClassification,
RobertaForTokenClassification, RobertaForQuestionAnswering,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,

View File

@ -377,7 +377,8 @@ def compute_predictions_logits(
output_null_log_odds_file,
verbose_logging,
version_2_with_negative,
null_score_diff_threshold
null_score_diff_threshold,
tokenizer,
):
"""Write final predictions to the json file and log-odds of null if needed."""
logger.info("Writing predictions to: %s" % (output_prediction_file))
@ -474,11 +475,14 @@ def compute_predictions_logits(
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
# tok_text = " ".join(tok_tokens)
#
# # De-tokenize WordPieces that have been split off.
# tok_text = tok_text.replace(" ##", "")
# tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()

View File

@ -4,6 +4,9 @@ import logging
import os
import json
import numpy as np
from multiprocessing import Pool
from multiprocessing import cpu_count
from functools import partial
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
from .utils import DataProcessor, InputExample, InputFeatures
@ -79,10 +82,168 @@ def _is_whitespace(c):
return True
return False
def squad_convert_example_to_features(example, max_seq_length,
doc_stride, max_query_length, is_training):
features = []
if is_training and not example.is_impossible:
# Get start and end position
start_position = example.start_position
end_position = example.end_position
def squad_convert_examples_to_features(
examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
):
# If the answer cannot be found in the text, then skip this example.
actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
return []
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []
for (i, token) in enumerate(example.doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1:
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
else:
tok_end_position = len(all_doc_tokens) - 1
(tok_start_position, tok_end_position) = _improve_answer_span(
all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
)
spans = []
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
span_doc_tokens = all_doc_tokens
while len(spans) * doc_stride < len(all_doc_tokens):
encoded_dict = tokenizer.encode_plus(
truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
max_length=max_seq_length,
return_overflowing_tokens=True,
pad_to_max_length=True,
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
)
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
if tokenizer.pad_token_id in encoded_dict['input_ids']:
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
else:
non_padded_ids = encoded_dict['input_ids']
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
token_to_orig_map = {}
for i in range(paragraph_len):
index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
encoded_dict["paragraph_len"] = paragraph_len
encoded_dict["tokens"] = tokens
encoded_dict["token_to_orig_map"] = token_to_orig_map
encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
encoded_dict["token_is_max_context"] = {}
encoded_dict["start"] = len(spans) * doc_stride
encoded_dict["length"] = paragraph_len
spans.append(encoded_dict)
if "overflowing_tokens" not in encoded_dict:
break
span_doc_tokens = encoded_dict["overflowing_tokens"]
for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = j if tokenizer.padding_side == "left" else spans[doc_span_index][
"truncated_query_with_special_tokens_length"] + j
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
for span in spans:
# Identify the position of the CLS token
cls_index = span['input_ids'].index(tokenizer.cls_token_id)
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# Original TF implem also keep the classification token (set to 0) (not sure why...)
p_mask = np.array(span['token_type_ids'])
p_mask = np.minimum(p_mask, 1)
if tokenizer.padding_side == "right":
# Limit positive values to one
p_mask = 1 - p_mask
p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
# Set the CLS index to '0'
p_mask[cls_index] = 0
span_is_impossible = example.is_impossible
start_position = 0
end_position = 0
if is_training and not span_is_impossible:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start = span["start"]
doc_end = span["start"] + span["length"] - 1
out_of_span = False
if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
out_of_span = True
if out_of_span:
start_position = cls_index
end_position = cls_index
span_is_impossible = True
else:
if tokenizer.padding_side == "left":
doc_offset = 0
else:
doc_offset = len(truncated_query) + sequence_added_tokens
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
features.append(SquadFeatures(
span['input_ids'],
span['attention_mask'],
span['token_type_ids'],
cls_index,
p_mask.tolist(),
example_index=0,
unique_id=0,
paragraph_len=span['paragraph_len'],
token_is_max_context=span["token_is_max_context"],
tokens=span["tokens"],
token_to_orig_map=span["token_to_orig_map"],
start_position=start_position,
end_position=end_position
))
return features
def squad_convert_example_to_features_init(tokenizer_for_convert):
global tokenizer
tokenizer = tokenizer_for_convert
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
doc_stride, max_query_length, is_training,
return_dataset=False, threads=1):
"""
Converts a list of examples into a list of features that can be directly given as input to a model.
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
@ -97,6 +258,8 @@ def squad_convert_examples_to_features(
return_dataset: Default False. Either 'pt' or 'tf'.
if 'pt': returns a torch.data.TensorDataset,
if 'tf': returns a tf.data.Dataset
threads: multiple processing threadsa-smi
Returns:
list of :class:`~transformers.data.processors.squad.SquadFeatures`
@ -116,172 +279,28 @@ def squad_convert_examples_to_features(
)
"""
# Defining helper methods
unique_id = 1000000000
# Defining helper methods
features = []
for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")):
if is_training and not example.is_impossible:
# Get start and end position
start_position = example.start_position
end_position = example.end_position
# If the answer cannot be found in the text, then skip this example.
actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
continue
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []
for (i, token) in enumerate(example.doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1:
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
else:
tok_end_position = len(all_doc_tokens) - 1
(tok_start_position, tok_end_position) = _improve_answer_span(
all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
)
spans = []
truncated_query = tokenizer.encode(
example.question_text, add_special_tokens=False, max_length=max_query_length
)
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
span_doc_tokens = all_doc_tokens
while len(spans) * doc_stride < len(all_doc_tokens):
encoded_dict = tokenizer.encode_plus(
truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
max_length=max_seq_length,
return_overflowing_tokens=True,
pad_to_max_length=True,
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
)
paragraph_len = min(
len(all_doc_tokens) - len(spans) * doc_stride,
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
)
if tokenizer.pad_token_id in encoded_dict["input_ids"]:
non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
else:
non_padded_ids = encoded_dict["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
token_to_orig_map = {}
for i in range(paragraph_len):
index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
encoded_dict["paragraph_len"] = paragraph_len
encoded_dict["tokens"] = tokens
encoded_dict["token_to_orig_map"] = token_to_orig_map
encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
encoded_dict["token_is_max_context"] = {}
encoded_dict["start"] = len(spans) * doc_stride
encoded_dict["length"] = paragraph_len
spans.append(encoded_dict)
if "overflowing_tokens" not in encoded_dict:
break
span_doc_tokens = encoded_dict["overflowing_tokens"]
for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = (
j
if tokenizer.padding_side == "left"
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
)
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
for span in spans:
# Identify the position of the CLS token
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# Original TF implem also keep the classification token (set to 0) (not sure why...)
p_mask = np.array(span["token_type_ids"])
p_mask = np.minimum(p_mask, 1)
if tokenizer.padding_side == "right":
# Limit positive values to one
p_mask = 1 - p_mask
p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
# Set the CLS index to '0'
p_mask[cls_index] = 0
span_is_impossible = example.is_impossible
start_position = 0
end_position = 0
if is_training and not span_is_impossible:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start = span["start"]
doc_end = span["start"] + span["length"] - 1
out_of_span = False
if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
out_of_span = True
if out_of_span:
start_position = cls_index
end_position = cls_index
span_is_impossible = True
else:
if tokenizer.padding_side == "left":
doc_offset = 0
else:
doc_offset = len(truncated_query) + sequence_added_tokens
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
features.append(
SquadFeatures(
span["input_ids"],
span["attention_mask"],
span["token_type_ids"],
cls_index,
p_mask.tolist(),
example_index=example_index,
unique_id=unique_id,
paragraph_len=span["paragraph_len"],
token_is_max_context=span["token_is_max_context"],
tokens=span["tokens"],
token_to_orig_map=span["token_to_orig_map"],
start_position=start_position,
end_position=end_position,
)
)
threads = min(threads, cpu_count())
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length,
doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features'))
new_features = []
unique_id = 1000000000
example_index = 0
for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
if not example_features:
continue
for example_feature in example_features:
example_feature.example_index = example_index
example_feature.unique_id = unique_id
new_features.append(example_feature)
unique_id += 1
if return_dataset == "pt":
example_index += 1
features = new_features
del new_features
if return_dataset == 'pt':
if not is_torch_available():
raise ImportError("Pytorch must be installed to return a pytorch dataset.")

View File

@ -555,3 +555,89 @@ class RobertaClassificationHead(nn.Module):
x = self.dropout(x)
x = self.out_proj(x)
return x
@add_start_docstrings("""Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForQuestionAnswering(BertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMultipleChoice.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:2]
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaForQuestionAnswering, self).__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
start_positions=None, end_positions=None):
outputs = self.roberta(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)