Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Merge branch 'refs/heads/squad_roberta'
# Conflicts:
#	transformers/data/processors/squad.py
This commit is contained in:
commit
c7780700f5
@@ -39,6 +39,7 @@ from tqdm import tqdm, trange

 from transformers import (WEIGHTS_NAME, BertConfig,
                           BertForQuestionAnswering, BertTokenizer,
+                          RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
                           XLMConfig, XLMForQuestionAnswering,
                           XLMTokenizer, XLNetConfig,
                           XLNetForQuestionAnswering,
@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e

 logger = logging.getLogger(__name__)

 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
-                  for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
+                  for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())

 MODEL_CLASSES = {
     'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
     'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
     'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
     'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
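Note: the `MODEL_CLASSES` registry is what lets `run_squad.py` resolve `--model_type` into a (config, model, tokenizer) triple. A minimal sketch of that lookup, mirroring the script's pattern (the checkpoint name is illustrative, not from the diff):

from transformers import RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer

MODEL_CLASSES = {
    'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
}

# Resolve --model_type into concrete classes, then load each from a checkpoint name.
config_class, model_class, tokenizer_class = MODEL_CLASSES['roberta']
config = config_class.from_pretrained('roberta-base')
tokenizer = tokenizer_class.from_pretrained('roberta-base')
model = model_class.from_pretrained('roberta-base', config=config)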
@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
+               'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
                'start_positions': batch[3],
-               'end_positions': batch[4]
+               'end_positions': batch[4],
            }

-           if args.model_type != 'distilbert':
-               inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
-
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})

@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
        with torch.no_grad():
            inputs = {
                'input_ids': batch[0],
-               'attention_mask': batch[1]
+               'attention_mask': batch[1],
+               'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
            }

-           if args.model_type != 'distilbert':
-               inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
-
            example_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
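Note: both hunks above fold the trailing `if args.model_type != 'distilbert'` fixup into the input dict itself. RoBERTa, like DistilBERT and XLM here, has no meaningful segment embeddings, so it receives `token_type_ids=None`. A standalone sketch of the equivalent logic (the helper name is hypothetical, not part of the diff):

def build_qa_inputs(batch, model_type):
    # Hypothetical helper restating the dispatch from the diff: models without
    # (useful) segment embeddings get token_type_ids=None instead of batch[2].
    return {
        'input_ids': batch[0],
        'attention_mask': batch[1],
        'token_type_ids': None if model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
    }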
@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
    predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
                                             args.max_answer_length, args.do_lower_case, output_prediction_file,
                                             output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                                            args.version_2_with_negative, args.null_score_diff_threshold)
+                                            args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
@@ -363,7 +360,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
-           return_dataset='pt'
+           return_dataset='pt',
+           threads=args.threads,
        )

    if args.local_rank in [-1, 0]:
@@ -481,6 +479,8 @@ def main():
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
+
+   parser.add_argument('--threads', type=int, default=1, help='multiple threads for converting example to features')
    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
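Note: with the new argument in place, an illustrative invocation looks like this (paths and hyperparameter values are placeholders, not taken from the diff):

python run_squad.py \
  --model_type roberta \
  --model_name_or_path roberta-base \
  --do_train \
  --do_eval \
  --train_file train-v1.1.json \
  --predict_file dev-v1.1.json \
  --output_dir /tmp/squad_roberta \
  --threads 4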
srl_label.txt (new file, 130 lines)
@@ -0,0 +1,130 @@
+O
+I-ARG1
+I-ARG2
+I-ARG0
+B-V
+B-ARG1
+B-ARG0
+I-ARGM-ADV
+I-ARGM-TMP
+B-ARG2
+I-ARGM-LOC
+I-ARGM-MNR
+B-ARGM-TMP
+I-ARGM-CAU
+I-ARGM-PRP
+B-ARGM-MOD
+I-C-ARG1
+B-ARGM-ADV
+I-ARGM-PRD
+B-ARGM-DIS
+I-ARG3
+I-V
+I-ARG4
+B-ARGM-MNR
+B-ARGM-LOC
+I-ARGM-NEG
+B-ARGM-NEG
+B-R-ARG0
+I-ARGM-DIR
+I-ARGM-DIS
+I-ARGM-PNC
+I-ARGM-ADJ
+B-R-ARG1
+B-ARG3
+B-ARGM-PRP
+B-ARG4
+I-ARGM-GOL
+I-R-ARG0
+B-ARGM-CAU
+B-ARGM-DIR
+B-ARGM-PRD
+I-ARGM-EXT
+B-C-ARG1
+B-ARGM-ADJ
+I-C-ARG0
+B-ARGM-EXT
+I-C-ARG2
+I-ARGM-COM
+I-R-ARG1
+I-ARGM-MOD
+B-ARGM-GOL
+B-ARGM-PNC
+B-R-ARGM-LOC
+B-R-ARGM-TMP
+B-ARGM-LVB
+B-ARGM-COM
+B-R-ARG2
+I-C-ARGM-MNR
+B-C-ARG0
+I-R-ARGM-LOC
+B-C-ARG2
+I-C-ARGM-EXT
+I-C-ARG4
+B-ARGM-REC
+I-R-ARG2
+I-C-ARGM-TMP
+I-ARG5
+I-C-ARG3
+I-C-ARGM-ADV
+B-ARG5
+B-R-ARGM-MNR
+I-ARGM-DSP
+I-C-ARGM-LOC
+B-R-ARG3
+I-ARGA
+I-R-ARGM-MNR
+B-R-ARGM-CAU
+I-R-ARGM-TMP
+B-C-ARGM-MNR
+B-ARGA
+I-C-ARGM-DSP
+B-C-ARGM-ADV
+I-R-ARG3
+B-R-ARGM-ADV
+B-C-ARG4
+I-C-ARGM-CAU
+B-C-ARGM-EXT
+B-C-ARGM-TMP
+B-R-ARGM-DIR
+B-R-ARG4
+I-R-ARGM-ADV
+I-ARGM-REC
+B-C-ARG3
+B-C-ARGM-LOC
+B-R-ARGM-EXT
+B-ARGM-PRR
+B-R-ARGM-PRP
+B-ARGM-PRX
+I-R-ARGM-DIR
+I-R-ARGM-EXT
+I-C-ARGM-NEG
+B-ARGM-DSP
+B-R-ARGM-GOL
+I-R-ARGM-GOL
+I-R-ARGM-PNC
+I-C-ARGM-PRP
+B-R-ARGM-COM
+I-R-ARGM-PRP
+I-C-ARGM-COM
+B-C-ARGM-CAU
+B-C-ARGM-DSP
+I-R-ARGM-COM
+I-R-ARGM-CAU
+B-R-ARGM-PNC
+I-C-ARGM-DIS
+I-C-ARGM-DIR
+I-R-ARG4
+B-R-ARGM-PRD
+I-R-ARGM-PRD
+B-C-ARGM-PRP
+B-R-ARG5
+B-C-ARGM-MOD
+I-C-ARGM-MOD
+B-C-ARGM-ADJ
+I-C-ARGM-ADJ
+B-C-ARGM-DIS
+B-C-ARGM-NEG
+B-C-ARGM-COM
+B-C-ARGM-DIR
+B-R-ARGM-MOD
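Note: these are BIO-encoded PropBank-style semantic-role labels (`B-`/`I-` mark the beginning/inside of a span; `ARG0`-`ARG5` are core arguments, `ARGM-*` modifiers, and `R-`/`C-` prefixes references and continuations). A minimal sketch of loading the committed file into the usual label/id mappings:

# Minimal sketch: build label <-> id mappings from the committed file.
with open('srl_label.txt') as f:
    labels = [line.strip() for line in f if line.strip()]

label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}
assert len(labels) == 130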
@@ -99,7 +99,7 @@ if is_torch_available():
                              XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
                                   RobertaForSequenceClassification, RobertaForMultipleChoice,
-                                  RobertaForTokenClassification,
+                                  RobertaForTokenClassification, RobertaForQuestionAnswering,
                                   ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
                                      DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
@@ -377,7 +377,8 @@ def compute_predictions_logits(
    output_null_log_odds_file,
    verbose_logging,
    version_2_with_negative,
-   null_score_diff_threshold
+   null_score_diff_threshold,
+   tokenizer,
 ):
    """Write final predictions to the json file and log-odds of null if needed."""
    logger.info("Writing predictions to: %s" % (output_prediction_file))
@@ -474,11 +475,14 @@ def compute_predictions_logits(
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
-               tok_text = " ".join(tok_tokens)

-               # De-tokenize WordPieces that have been split off.
-               tok_text = tok_text.replace(" ##", "")
-               tok_text = tok_text.replace("##", "")
+               tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+               # tok_text = " ".join(tok_tokens)
+               #
+               # # De-tokenize WordPieces that have been split off.
+               # tok_text = tok_text.replace(" ##", "")
+               # tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
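Note: this edit matters for RoBERTa. The old code hard-coded WordPiece detokenization (stripping `##` markers), which is wrong for BPE vocabularies, while `convert_tokens_to_string` delegates to whichever tokenizer produced the pieces. An illustrative contrast (the example word and the exact sub-token splits are assumptions):

from transformers import BertTokenizer, RobertaTokenizer

bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
pieces = bert_tok.tokenize("unaffable")              # WordPiece, e.g. ['una', '##ffa', '##ble']
print(bert_tok.convert_tokens_to_string(pieces))     # 'unaffable'

roberta_tok = RobertaTokenizer.from_pretrained('roberta-base')
pieces = roberta_tok.tokenize(" unaffable")          # byte-level BPE pieces, no '##' markers
print(roberta_tok.convert_tokens_to_string(pieces))  # ' unaffable'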
@@ -4,6 +4,9 @@ import logging
 import os
 import json
 import numpy as np
+from multiprocessing import Pool
+from multiprocessing import cpu_count
+from functools import partial

 from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
 from .utils import DataProcessor, InputExample, InputFeatures
@@ -79,10 +82,168 @@ def _is_whitespace(c):
        return True
    return False


-def squad_convert_examples_to_features(
-    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
-):
+def squad_convert_example_to_features(example, max_seq_length,
+                                      doc_stride, max_query_length, is_training):
+    features = []
+    if is_training and not example.is_impossible:
+        # Get start and end position
+        start_position = example.start_position
+        end_position = example.end_position
+
+        # If the answer cannot be found in the text, then skip this example.
+        actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
+        cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
+        if actual_text.find(cleaned_answer_text) == -1:
+            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
+            return []
+
+    tok_to_orig_index = []
+    orig_to_tok_index = []
+    all_doc_tokens = []
+    for (i, token) in enumerate(example.doc_tokens):
+        orig_to_tok_index.append(len(all_doc_tokens))
+        sub_tokens = tokenizer.tokenize(token)
+        for sub_token in sub_tokens:
+            tok_to_orig_index.append(i)
+            all_doc_tokens.append(sub_token)
+
+    if is_training and not example.is_impossible:
+        tok_start_position = orig_to_tok_index[example.start_position]
+        if example.end_position < len(example.doc_tokens) - 1:
+            tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+        else:
+            tok_end_position = len(all_doc_tokens) - 1
+
+        (tok_start_position, tok_end_position) = _improve_answer_span(
+            all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
+        )
+
+    spans = []
+
+    truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
+    sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
+        if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
+    sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
+
+    span_doc_tokens = all_doc_tokens
+    while len(spans) * doc_stride < len(all_doc_tokens):
+
+        encoded_dict = tokenizer.encode_plus(
+            truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
+            span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
+            max_length=max_seq_length,
+            return_overflowing_tokens=True,
+            pad_to_max_length=True,
+            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
+            truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
+        )
+
+        paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
+                            max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
+
+        if tokenizer.pad_token_id in encoded_dict['input_ids']:
+            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+        else:
+            non_padded_ids = encoded_dict['input_ids']
+
+        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
+
+        token_to_orig_map = {}
+        for i in range(paragraph_len):
+            index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
+            token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
+
+        encoded_dict["paragraph_len"] = paragraph_len
+        encoded_dict["tokens"] = tokens
+        encoded_dict["token_to_orig_map"] = token_to_orig_map
+        encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
+        encoded_dict["token_is_max_context"] = {}
+        encoded_dict["start"] = len(spans) * doc_stride
+        encoded_dict["length"] = paragraph_len
+
+        spans.append(encoded_dict)
+
+        if "overflowing_tokens" not in encoded_dict:
+            break
+        span_doc_tokens = encoded_dict["overflowing_tokens"]
+
+    for doc_span_index in range(len(spans)):
+        for j in range(spans[doc_span_index]["paragraph_len"]):
+            is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
+            index = j if tokenizer.padding_side == "left" else spans[doc_span_index][
+                "truncated_query_with_special_tokens_length"] + j
+            spans[doc_span_index]["token_is_max_context"][index] = is_max_context
+
+    for span in spans:
+        # Identify the position of the CLS token
+        cls_index = span['input_ids'].index(tokenizer.cls_token_id)
+
+        # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
+        # Original TF implem also keep the classification token (set to 0) (not sure why...)
+        p_mask = np.array(span['token_type_ids'])
+
+        p_mask = np.minimum(p_mask, 1)
+
+        if tokenizer.padding_side == "right":
+            # Limit positive values to one
+            p_mask = 1 - p_mask
+
+        p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
+
+        # Set the CLS index to '0'
+        p_mask[cls_index] = 0
+
+        span_is_impossible = example.is_impossible
+        start_position = 0
+        end_position = 0
+        if is_training and not span_is_impossible:
+            # For training, if our document chunk does not contain an annotation
+            # we throw it out, since there is nothing to predict.
+            doc_start = span["start"]
+            doc_end = span["start"] + span["length"] - 1
+            out_of_span = False
+
+            if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
+                out_of_span = True
+
+            if out_of_span:
+                start_position = cls_index
+                end_position = cls_index
+                span_is_impossible = True
+            else:
+                if tokenizer.padding_side == "left":
+                    doc_offset = 0
+                else:
+                    doc_offset = len(truncated_query) + sequence_added_tokens
+
+                start_position = tok_start_position - doc_start + doc_offset
+                end_position = tok_end_position - doc_start + doc_offset
+
+        features.append(SquadFeatures(
+            span['input_ids'],
+            span['attention_mask'],
+            span['token_type_ids'],
+            cls_index,
+            p_mask.tolist(),
+            example_index=0,
+            unique_id=0,
+            paragraph_len=span['paragraph_len'],
+            token_is_max_context=span["token_is_max_context"],
+            tokens=span["tokens"],
+            token_to_orig_map=span["token_to_orig_map"],
+            start_position=start_position,
+            end_position=end_position
+        ))
+    return features
+
+
+def squad_convert_example_to_features_init(tokenizer_for_convert):
+    global tokenizer
+    tokenizer = tokenizer_for_convert
+
+
+def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
+                                       doc_stride, max_query_length, is_training,
+                                       return_dataset=False, threads=1):
    """
    Converts a list of examples into a list of features that can be directly given as input to a model.
    It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
@@ -97,6 +258,8 @@ def squad_convert_examples_to_features(
        return_dataset: Default False. Either 'pt' or 'tf'.
            if 'pt': returns a torch.data.TensorDataset,
            if 'tf': returns a tf.data.Dataset
+       threads: multiple processing threads
+

    Returns:
        list of :class:`~transformers.data.processors.squad.SquadFeatures`
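Note: taken together with the docstring above, a call of the refactored converter looks roughly like this (the data path is a placeholder, and `SquadV1Processor` is assumed to be the companion processor class in the same module; with `return_dataset='pt'` the function is used as returning both features and a TensorDataset):

from transformers import RobertaTokenizer, squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV1Processor

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
examples = SquadV1Processor().get_train_examples('/path/to/squad')  # placeholder path

features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=True,
    return_dataset='pt',
    threads=4,  # new in this diff: parallel feature conversion
)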
@@ -116,172 +279,28 @@ def squad_convert_examples_to_features(
        )
    """

    # Defining helper methods
-   unique_id = 1000000000
-
    features = []
-   for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")):
-       if is_training and not example.is_impossible:
-           # Get start and end position
-           start_position = example.start_position
-           end_position = example.end_position
-
-           # If the answer cannot be found in the text, then skip this example.
-           actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
-           cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
-           if actual_text.find(cleaned_answer_text) == -1:
-               logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
-               continue
-
-       tok_to_orig_index = []
-       orig_to_tok_index = []
-       all_doc_tokens = []
-       for (i, token) in enumerate(example.doc_tokens):
-           orig_to_tok_index.append(len(all_doc_tokens))
-           sub_tokens = tokenizer.tokenize(token)
-           for sub_token in sub_tokens:
-               tok_to_orig_index.append(i)
-               all_doc_tokens.append(sub_token)
-
-       if is_training and not example.is_impossible:
-           tok_start_position = orig_to_tok_index[example.start_position]
-           if example.end_position < len(example.doc_tokens) - 1:
-               tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
-           else:
-               tok_end_position = len(all_doc_tokens) - 1
-
-           (tok_start_position, tok_end_position) = _improve_answer_span(
-               all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
-           )
-
-       spans = []
-
-       truncated_query = tokenizer.encode(
-           example.question_text, add_special_tokens=False, max_length=max_query_length
-       )
-       sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
-       sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
-
-       span_doc_tokens = all_doc_tokens
-       while len(spans) * doc_stride < len(all_doc_tokens):
-
-           encoded_dict = tokenizer.encode_plus(
-               truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
-               span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
-               max_length=max_seq_length,
-               return_overflowing_tokens=True,
-               pad_to_max_length=True,
-               stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-               truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
-           )
-
-           paragraph_len = min(
-               len(all_doc_tokens) - len(spans) * doc_stride,
-               max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
-           )
-
-           if tokenizer.pad_token_id in encoded_dict["input_ids"]:
-               non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
-           else:
-               non_padded_ids = encoded_dict["input_ids"]
-
-           tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
-
-           token_to_orig_map = {}
-           for i in range(paragraph_len):
-               index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
-               token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
-
-           encoded_dict["paragraph_len"] = paragraph_len
-           encoded_dict["tokens"] = tokens
-           encoded_dict["token_to_orig_map"] = token_to_orig_map
-           encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
-           encoded_dict["token_is_max_context"] = {}
-           encoded_dict["start"] = len(spans) * doc_stride
-           encoded_dict["length"] = paragraph_len
-
-           spans.append(encoded_dict)
-
-           if "overflowing_tokens" not in encoded_dict:
-               break
-           span_doc_tokens = encoded_dict["overflowing_tokens"]
-
-       for doc_span_index in range(len(spans)):
-           for j in range(spans[doc_span_index]["paragraph_len"]):
-               is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-               index = (
-                   j
-                   if tokenizer.padding_side == "left"
-                   else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
-               )
-               spans[doc_span_index]["token_is_max_context"][index] = is_max_context
-
-       for span in spans:
-           # Identify the position of the CLS token
-           cls_index = span["input_ids"].index(tokenizer.cls_token_id)
-
-           # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
-           # Original TF implem also keep the classification token (set to 0) (not sure why...)
-           p_mask = np.array(span["token_type_ids"])
-
-           p_mask = np.minimum(p_mask, 1)
-
-           if tokenizer.padding_side == "right":
-               # Limit positive values to one
-               p_mask = 1 - p_mask
-
-           p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
-
-           # Set the CLS index to '0'
-           p_mask[cls_index] = 0
-
-           span_is_impossible = example.is_impossible
-           start_position = 0
-           end_position = 0
-           if is_training and not span_is_impossible:
-               # For training, if our document chunk does not contain an annotation
-               # we throw it out, since there is nothing to predict.
-               doc_start = span["start"]
-               doc_end = span["start"] + span["length"] - 1
-               out_of_span = False
-
-               if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
-                   out_of_span = True
-
-               if out_of_span:
-                   start_position = cls_index
-                   end_position = cls_index
-                   span_is_impossible = True
-               else:
-                   if tokenizer.padding_side == "left":
-                       doc_offset = 0
-                   else:
-                       doc_offset = len(truncated_query) + sequence_added_tokens
-
-                   start_position = tok_start_position - doc_start + doc_offset
-                   end_position = tok_end_position - doc_start + doc_offset
-
-           features.append(
-               SquadFeatures(
-                   span["input_ids"],
-                   span["attention_mask"],
-                   span["token_type_ids"],
-                   cls_index,
-                   p_mask.tolist(),
-                   example_index=example_index,
-                   unique_id=unique_id,
-                   paragraph_len=span["paragraph_len"],
-                   token_is_max_context=span["token_is_max_context"],
-                   tokens=span["tokens"],
-                   token_to_orig_map=span["token_to_orig_map"],
-                   start_position=start_position,
-                   end_position=end_position,
-               )
-           )
+   threads = min(threads, cpu_count())
+   with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
+       annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length,
+                           doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
+       features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features'))
+   new_features = []
+   unique_id = 1000000000
+   example_index = 0
+   for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
+       if not example_features:
+           continue
+       for example_feature in example_features:
+           example_feature.example_index = example_index
+           example_feature.unique_id = unique_id
+           new_features.append(example_feature)
            unique_id += 1
+       example_index += 1
+   features = new_features
+   del new_features
-   if return_dataset == "pt":
+   if return_dataset == 'pt':
        if not is_torch_available():
            raise ImportError("Pytorch must be installed to return a pytorch dataset.")

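Note: the refactor's key trick is the pool initializer. The tokenizer is installed once as a module-level global in each worker (`squad_convert_example_to_features_init`), so it is not re-pickled for every example sent through `imap`. A self-contained sketch of the same pattern with toy data:

from functools import partial
from multiprocessing import Pool, cpu_count

shared = None

def _init(value):
    # Runs once per worker process; stores heavyweight state globally so each
    # task only ships the (small) per-item payload.
    global shared
    shared = value

def _work(item, scale=1):
    # Each task reads the worker-global state instead of receiving it again.
    return shared + item * scale

if __name__ == '__main__':
    with Pool(min(4, cpu_count()), initializer=_init, initargs=(100,)) as pool:
        out = list(pool.imap(partial(_work, scale=2), range(5), chunksize=2))
    print(out)  # [100, 102, 104, 106, 108]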
@@ -555,3 +555,89 @@ class RobertaClassificationHead(nn.Module):
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
+
+
+@add_start_docstrings("""Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
+class RobertaForQuestionAnswering(BertPreTrainedModel):
+    r"""
+        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-start scores (before SoftMax).
+        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-end scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+        model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
+        start_positions = torch.tensor([1])
+        end_positions = torch.tensor([3])
+        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
+        loss, start_scores, end_scores = outputs[:3]
+
+    """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
+
+    def __init__(self, config):
+        super(RobertaForQuestionAnswering, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.roberta = RobertaModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                start_positions=None, end_positions=None):
+
+        outputs = self.roberta(input_ids,
+                               attention_mask=attention_mask,
+                               token_type_ids=token_type_ids,
+                               position_ids=position_ids,
+                               head_mask=head_mask)
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
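Note: to close the loop, a hedged end-to-end sketch of the new head at inference time (the checkpoint is the base model, so without fine-tuning the extracted span is meaningless; this only shows the plumbing):

import torch
from transformers import RobertaForQuestionAnswering, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
model.eval()

question, context = "Who wrote the report?", "The report was written by Alice."
inputs = tokenizer.encode_plus(question, context, return_tensors='pt')

with torch.no_grad():
    start_logits, end_logits = model(**inputs)[:2]

start = torch.argmax(start_logits)  # most likely span start (batch size 1)
end = torch.argmax(end_logits)      # most likely span end
answer_ids = inputs['input_ids'][0][start:end + 1]
print(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids)))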