mirror of https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00

commit c7780700f5

Merge branch 'refs/heads/squad_roberta'

# Conflicts:
#	transformers/data/processors/squad.py
@@ -39,6 +39,7 @@ from tqdm import tqdm, trange

 from transformers import (WEIGHTS_NAME, BertConfig,
                           BertForQuestionAnswering, BertTokenizer,
+                          RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
                           XLMConfig, XLMForQuestionAnswering,
                           XLMTokenizer, XLNetConfig,
                           XLNetForQuestionAnswering,
@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e

 logger = logging.getLogger(__name__)

 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
-                  for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
+                  for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())

 MODEL_CLASSES = {
     'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
     'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
     'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
     'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
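Note: a minimal sketch of how run_squad.py typically consumes this mapping; the surrounding argparse setup (args.model_type, args.model_name_or_path, args.do_lower_case) is assumed rather than shown in this diff:

    # Hypothetical illustration of the MODEL_CLASSES lookup in run_squad.py.
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
    model = model_class.from_pretrained(args.model_name_or_path, config=config)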
@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
             inputs = {
                 'input_ids': batch[0],
                 'attention_mask': batch[1],
+                'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
                 'start_positions': batch[3],
-                'end_positions': batch[4]
+                'end_positions': batch[4],
             }

-            if args.model_type != 'distilbert':
-                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
-
             if args.model_type in ['xlnet', 'xlm']:
                 inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
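Note: RoBERTa joins XLM and DistilBERT here because none of them consume segment ids in this script, so the dict now sets token_type_ids inline instead of patching it after construction. A sketch of the equivalent helper, names invented for illustration:

    # Sketch (not in the diff): equivalent to the inline conditional above.
    def segment_ids_for(model_type, batch):
        # RoBERTa/XLM/DistilBERT take no token type ids for this task.
        return None if model_type in ['xlm', 'roberta', 'distilbert'] else batch[2]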
@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
             with torch.no_grad():
                 inputs = {
                     'input_ids': batch[0],
-                    'attention_mask': batch[1]
+                    'attention_mask': batch[1],
+                    'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
                 }

-                if args.model_type != 'distilbert':
-                    inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
-
                 example_indices = batch[3]

                 # XLNet and XLM use more arguments for their predictions
@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
                                              args.max_answer_length, args.do_lower_case, output_prediction_file,
                                              output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                                             args.version_2_with_negative, args.null_score_diff_threshold)
+                                             args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)

     # Compute the F1 and exact scores.
     results = squad_evaluate(examples, predictions)
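Note: squad_evaluate then reduces the written predictions to SQuAD metrics; a hedged sketch of inspecting the result, with the metric key names ('exact', 'f1') assumed rather than confirmed by this diff:

    # Hypothetical inspection of the metrics dict returned by squad_evaluate.
    results = squad_evaluate(examples, predictions)
    print("EM:", results.get("exact"), "F1:", results.get("f1"))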
@@ -363,7 +360,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
             doc_stride=args.doc_stride,
             max_query_length=args.max_query_length,
             is_training=not evaluate,
-            return_dataset='pt'
+            return_dataset='pt',
+            threads=args.threads,
         )

     if args.local_rank in [-1, 0]:
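Note: with the new keyword, feature conversion fans out across processes. A sketch of the call as the script now makes it, assuming the args fields from the surrounding argparse setup and the "returns a TensorDataset" contract documented in the squad.py docstring below:

    # Sketch: the new threads kwarg parallelizes example-to-feature conversion.
    dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=not evaluate,
        return_dataset='pt',
        threads=args.threads,
    )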
@@ -481,6 +479,8 @@ def main():
                              "See details at https://nvidia.github.io/apex/amp.html")
     parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
     parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
+
+    parser.add_argument('--threads', type=int, default=1, help='multiple threads for converting examples to features')
     args = parser.parse_args()

     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
srl_label.txt (new file, 130 lines)
@@ -0,0 +1,130 @@
O
I-ARG1
I-ARG2
I-ARG0
B-V
B-ARG1
B-ARG0
I-ARGM-ADV
I-ARGM-TMP
B-ARG2
I-ARGM-LOC
I-ARGM-MNR
B-ARGM-TMP
I-ARGM-CAU
I-ARGM-PRP
B-ARGM-MOD
I-C-ARG1
B-ARGM-ADV
I-ARGM-PRD
B-ARGM-DIS
I-ARG3
I-V
I-ARG4
B-ARGM-MNR
B-ARGM-LOC
I-ARGM-NEG
B-ARGM-NEG
B-R-ARG0
I-ARGM-DIR
I-ARGM-DIS
I-ARGM-PNC
I-ARGM-ADJ
B-R-ARG1
B-ARG3
B-ARGM-PRP
B-ARG4
I-ARGM-GOL
I-R-ARG0
B-ARGM-CAU
B-ARGM-DIR
B-ARGM-PRD
I-ARGM-EXT
B-C-ARG1
B-ARGM-ADJ
I-C-ARG0
B-ARGM-EXT
I-C-ARG2
I-ARGM-COM
I-R-ARG1
I-ARGM-MOD
B-ARGM-GOL
B-ARGM-PNC
B-R-ARGM-LOC
B-R-ARGM-TMP
B-ARGM-LVB
B-ARGM-COM
B-R-ARG2
I-C-ARGM-MNR
B-C-ARG0
I-R-ARGM-LOC
B-C-ARG2
I-C-ARGM-EXT
I-C-ARG4
B-ARGM-REC
I-R-ARG2
I-C-ARGM-TMP
I-ARG5
I-C-ARG3
I-C-ARGM-ADV
B-ARG5
B-R-ARGM-MNR
I-ARGM-DSP
I-C-ARGM-LOC
B-R-ARG3
I-ARGA
I-R-ARGM-MNR
B-R-ARGM-CAU
I-R-ARGM-TMP
B-C-ARGM-MNR
B-ARGA
I-C-ARGM-DSP
B-C-ARGM-ADV
I-R-ARG3
B-R-ARGM-ADV
B-C-ARG4
I-C-ARGM-CAU
B-C-ARGM-EXT
B-C-ARGM-TMP
B-R-ARGM-DIR
B-R-ARG4
I-R-ARGM-ADV
I-ARGM-REC
B-C-ARG3
B-C-ARGM-LOC
B-R-ARGM-EXT
B-ARGM-PRR
B-R-ARGM-PRP
B-ARGM-PRX
I-R-ARGM-DIR
I-R-ARGM-EXT
I-C-ARGM-NEG
B-ARGM-DSP
B-R-ARGM-GOL
I-R-ARGM-GOL
I-R-ARGM-PNC
I-C-ARGM-PRP
B-R-ARGM-COM
I-R-ARGM-PRP
I-C-ARGM-COM
B-C-ARGM-CAU
B-C-ARGM-DSP
I-R-ARGM-COM
I-R-ARGM-CAU
B-R-ARGM-PNC
I-C-ARGM-DIS
I-C-ARGM-DIR
I-R-ARG4
B-R-ARGM-PRD
I-R-ARGM-PRD
B-C-ARGM-PRP
B-R-ARG5
B-C-ARGM-MOD
I-C-ARGM-MOD
B-C-ARGM-ADJ
I-C-ARGM-ADJ
B-C-ARGM-DIS
B-C-ARGM-NEG
B-C-ARGM-COM
B-C-ARGM-DIR
B-R-ARGM-MOD
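Note: srl_label.txt is a BIO label inventory for semantic role labeling: B-/I- prefixes mark the beginning/inside of an argument span (ARG0 agent, ARG1 patient, B-V the predicate, ARGM-* modifiers, R-*/C-* reference and continuation arguments), and O marks tokens outside any span. A minimal, hedged sketch of decoding such tags into spans; the tag sequence is invented for illustration, and I- tags without a preceding B- are simply ignored:

    def bio_to_spans(tags):
        spans, start = [], None
        for i, tag in enumerate(tags + ['O']):  # sentinel flushes the last open span
            if start is not None and not tag.startswith('I-'):
                spans.append((start, i, tags[start][2:]))  # (begin, end, role)
                start = None
            if tag.startswith('B-'):
                start = i
        return spans

    print(bio_to_spans(['B-ARG0', 'B-V', 'B-ARG1', 'I-ARG1', 'O']))
    # [(0, 1, 'ARG0'), (1, 2, 'V'), (2, 4, 'ARG1')]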
@@ -99,7 +99,7 @@ if is_torch_available():
                               XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
                                    RobertaForSequenceClassification, RobertaForMultipleChoice,
-                                   RobertaForTokenClassification,
+                                   RobertaForTokenClassification, RobertaForQuestionAnswering,
                                    ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
                                       DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
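Note: once exported from the package root here, the new head resolves like any other model class; a hedged smoke test (downloads pretrained weights, and the QA head itself stays randomly initialized until fine-tuned):

    from transformers import RobertaForQuestionAnswering

    model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
    print(model.qa_outputs)  # Linear(hidden_size, 2): start/end logits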
@@ -377,7 +377,8 @@ def compute_predictions_logits(
     output_null_log_odds_file,
     verbose_logging,
     version_2_with_negative,
-    null_score_diff_threshold
+    null_score_diff_threshold,
+    tokenizer,
 ):
     """Write final predictions to the json file and log-odds of null if needed."""
     logger.info("Writing predictions to: %s" % (output_prediction_file))
@@ -474,11 +475,14 @@ def compute_predictions_logits(
                 orig_doc_start = feature.token_to_orig_map[pred.start_index]
                 orig_doc_end = feature.token_to_orig_map[pred.end_index]
                 orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
-                tok_text = " ".join(tok_tokens)
-
-                # De-tokenize WordPieces that have been split off.
-                tok_text = tok_text.replace(" ##", "")
-                tok_text = tok_text.replace("##", "")
+                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+
+                # tok_text = " ".join(tok_tokens)
+                #
+                # # De-tokenize WordPieces that have been split off.
+                # tok_text = tok_text.replace(" ##", "")
+                # tok_text = tok_text.replace("##", "")

                 # Clean whitespace
                 tok_text = tok_text.strip()
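Note: convert_tokens_to_string makes the de-tokenization model-agnostic, which is why the tokenizer is now threaded through compute_predictions_logits: the old inline code only knew WordPiece's '##' continuation markers, while RoBERTa's byte-level BPE uses 'Ġ' prefixes. A hedged comparison sketch, tokenizer names assumed:

    from transformers import BertTokenizer, RobertaTokenizer

    bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
    roberta_tok = RobertaTokenizer.from_pretrained('roberta-base')

    text = "unaffordable housing"
    print(bert_tok.convert_tokens_to_string(bert_tok.tokenize(text)))        # rejoins '##' WordPieces
    print(roberta_tok.convert_tokens_to_string(roberta_tok.tokenize(text)))  # decodes 'Ġ' BPE markers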
@@ -4,6 +4,9 @@ import logging
 import os
 import json
 import numpy as np
+from multiprocessing import Pool
+from multiprocessing import cpu_count
+from functools import partial

 from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
 from .utils import DataProcessor, InputExample, InputFeatures
@@ -79,48 +82,9 @@ def _is_whitespace(c):
         return True
     return False

-def squad_convert_examples_to_features(
-    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
-):
-    """
-    Converts a list of examples into a list of features that can be directly given as input to a model.
-    It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
-
-    Args:
-        examples: list of :class:`~transformers.data.processors.squad.SquadExample`
-        tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer`
-        max_seq_length: The maximum sequence length of the inputs.
-        doc_stride: The stride used when the context is too large and is split across several features.
-        max_query_length: The maximum length of the query.
-        is_training: whether to create features for model evaluation or model training.
-        return_dataset: Default False. Either 'pt' or 'tf'.
-            if 'pt': returns a torch.data.TensorDataset,
-            if 'tf': returns a tf.data.Dataset
-
-    Returns:
-        list of :class:`~transformers.data.processors.squad.SquadFeatures`
-
-    Example::
-
-        processor = SquadV2Processor()
-        examples = processor.get_dev_examples(data_dir)
-
-        features = squad_convert_examples_to_features(
-            examples=examples,
-            tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            doc_stride=args.doc_stride,
-            max_query_length=args.max_query_length,
-            is_training=not evaluate,
-        )
-    """
-
-    # Defining helper methods
-    unique_id = 1000000000
-
-    features = []
-    for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")):
-        if is_training and not example.is_impossible:
+def squad_convert_example_to_features(example, max_seq_length,
+                                      doc_stride, max_query_length, is_training):
+    features = []
+    if is_training and not example.is_impossible:
         # Get start and end position
         start_position = example.start_position
@@ -131,7 +95,7 @@ def squad_convert_examples_to_features(
             cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
             if actual_text.find(cleaned_answer_text) == -1:
                 logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
-                continue
+                return []

     tok_to_orig_index = []
     orig_to_tok_index = []
@@ -156,10 +120,9 @@ def squad_convert_examples_to_features(

     spans = []

-    truncated_query = tokenizer.encode(
-        example.question_text, add_special_tokens=False, max_length=max_query_length
-    )
-    sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
+    truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
+    sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
+        if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
     sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

     span_doc_tokens = all_doc_tokens
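Note: the RoBERTa special case exists because its pair template, <s> A </s></s> B </s>, carries one more special token than a BERT-style [CLS] A [SEP] B [SEP], so the query-side offset needs the extra +1. A hedged sanity check, attribute names as used in the diff:

    from transformers import BertTokenizer, RobertaTokenizer

    bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
    roberta_tok = RobertaTokenizer.from_pretrained('roberta-base')

    print(bert_tok.max_len - bert_tok.max_len_sentences_pair)        # 3: [CLS] A [SEP] B [SEP]
    print(roberta_tok.max_len - roberta_tok.max_len_sentences_pair)  # 4: <s> A </s></s> B </s>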
@@ -172,18 +135,16 @@ def squad_convert_examples_to_features(
             return_overflowing_tokens=True,
             pad_to_max_length=True,
             stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-            truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
+            truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
         )

-        paragraph_len = min(
-            len(all_doc_tokens) - len(spans) * doc_stride,
-            max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
-        )
+        paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
+                            max_seq_length - len(truncated_query) - sequence_pair_added_tokens)

-        if tokenizer.pad_token_id in encoded_dict["input_ids"]:
-            non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
+        if tokenizer.pad_token_id in encoded_dict['input_ids']:
+            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
         else:
-            non_padded_ids = encoded_dict["input_ids"]
+            non_padded_ids = encoded_dict['input_ids']

         tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
@@ -209,20 +170,17 @@ def squad_convert_examples_to_features(
     for doc_span_index in range(len(spans)):
         for j in range(spans[doc_span_index]["paragraph_len"]):
             is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-            index = (
-                j
-                if tokenizer.padding_side == "left"
-                else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
-            )
+            index = j if tokenizer.padding_side == "left" else spans[doc_span_index][
+                "truncated_query_with_special_tokens_length"] + j
             spans[doc_span_index]["token_is_max_context"][index] = is_max_context

     for span in spans:
         # Identify the position of the CLS token
-        cls_index = span["input_ids"].index(tokenizer.cls_token_id)
+        cls_index = span['input_ids'].index(tokenizer.cls_token_id)

         # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
         # Original TF implem also keep the classification token (set to 0) (not sure why...)
-        p_mask = np.array(span["token_type_ids"])
+        p_mask = np.array(span['token_type_ids'])

         p_mask = np.minimum(p_mask, 1)
@@ -261,27 +219,88 @@ def squad_convert_examples_to_features(
             start_position = tok_start_position - doc_start + doc_offset
             end_position = tok_end_position - doc_start + doc_offset

-        features.append(
-            SquadFeatures(
-                span["input_ids"],
-                span["attention_mask"],
-                span["token_type_ids"],
+        features.append(SquadFeatures(
+            span['input_ids'],
+            span['attention_mask'],
+            span['token_type_ids'],
             cls_index,
             p_mask.tolist(),
-                example_index=example_index,
-                unique_id=unique_id,
-                paragraph_len=span["paragraph_len"],
+            example_index=0,
+            unique_id=0,
+            paragraph_len=span['paragraph_len'],
             token_is_max_context=span["token_is_max_context"],
             tokens=span["tokens"],
             token_to_orig_map=span["token_to_orig_map"],
             start_position=start_position,
-                end_position=end_position,
-            )
-        )
+            end_position=end_position
+        ))
+    return features
+
+
+def squad_convert_example_to_features_init(tokenizer_for_convert):
+    global tokenizer
+    tokenizer = tokenizer_for_convert
+
+
+def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
+                                       doc_stride, max_query_length, is_training,
+                                       return_dataset=False, threads=1):
+    """
+    Converts a list of examples into a list of features that can be directly given as input to a model.
+    It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
+
+    Args:
+        examples: list of :class:`~transformers.data.processors.squad.SquadExample`
+        tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer`
+        max_seq_length: The maximum sequence length of the inputs.
+        doc_stride: The stride used when the context is too large and is split across several features.
+        max_query_length: The maximum length of the query.
+        is_training: whether to create features for model evaluation or model training.
+        return_dataset: Default False. Either 'pt' or 'tf'.
+            if 'pt': returns a torch.data.TensorDataset,
+            if 'tf': returns a tf.data.Dataset
+        threads: multiple processing threads.
+
+    Returns:
+        list of :class:`~transformers.data.processors.squad.SquadFeatures`
+
+    Example::
+
+        processor = SquadV2Processor()
+        examples = processor.get_dev_examples(data_dir)
+
+        features = squad_convert_examples_to_features(
+            examples=examples,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            doc_stride=args.doc_stride,
+            max_query_length=args.max_query_length,
+            is_training=not evaluate,
+        )
+    """
+
+    # Defining helper methods
+    features = []
+    threads = min(threads, cpu_count())
+    with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
+        annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length,
+                            doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
+        features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features'))
+    new_features = []
+    unique_id = 1000000000
+    example_index = 0
+    for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
+        if not example_features:
+            continue
+        for example_feature in example_features:
+            example_feature.example_index = example_index
+            example_feature.unique_id = unique_id
+            new_features.append(example_feature)
+            unique_id += 1
+        example_index += 1
+    features = new_features
+    del new_features
-    if return_dataset == "pt":
+    if return_dataset == 'pt':
         if not is_torch_available():
             raise ImportError("Pytorch must be installed to return a pytorch dataset.")
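Note: a standalone sketch of the Pool initializer pattern introduced above; the tokenizer is handed to each worker once via a module-level global instead of being pickled into every imap task. Names here are illustrative:

    from multiprocessing import Pool

    def _init_worker(tokenizer_for_convert):
        global tokenizer
        tokenizer = tokenizer_for_convert

    def _count_tokens(text):
        return len(tokenizer.tokenize(text))

    if __name__ == '__main__':
        from transformers import RobertaTokenizer
        tok = RobertaTokenizer.from_pretrained('roberta-base')
        with Pool(2, initializer=_init_worker, initargs=(tok,)) as p:
            print(p.map(_count_tokens, ["a first text", "a second text"]))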
@@ -555,3 +555,89 @@ class RobertaClassificationHead(nn.Module):
         x = self.dropout(x)
         x = self.out_proj(x)
         return x
+
+
+@add_start_docstrings("""Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
+class RobertaForQuestionAnswering(BertPreTrainedModel):
+    r"""
+        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-start scores (before SoftMax).
+        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-end scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+        model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
+        start_positions = torch.tensor([1])
+        end_positions = torch.tensor([3])
+        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
+        loss, start_scores, end_scores = outputs[:3]
+    """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
+
+    def __init__(self, config):
+        super(RobertaForQuestionAnswering, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.roberta = RobertaModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                start_positions=None, end_positions=None):
+
+        outputs = self.roberta(input_ids,
+                               attention_mask=attention_mask,
+                               token_type_ids=token_type_ids,
+                               position_ids=position_ids,
+                               head_mask=head_mask)
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, the split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
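Note: a hedged end-to-end sketch of the new head on a question/context pair; until the head is fine-tuned on SQuAD the extracted span is meaningless, and the pair encoding relies on encode(text, text_pair):

    import torch
    from transformers import RobertaTokenizer, RobertaForQuestionAnswering

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaForQuestionAnswering.from_pretrained('roberta-base')

    question, context = "Who wrote the play?", "The play was written by Shakespeare."
    input_ids = torch.tensor([tokenizer.encode(question, context)])
    start_scores, end_scores = model(input_ids)[:2]
    start, end = start_scores.argmax(), end_scores.argmax()  # most likely span boundaries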