SQuAD v2 BERT + XLNet

Lysandre 2019-11-25 19:22:21 -05:00
parent e0e55bc550
commit 0669c1fcd1
4 changed files with 92 additions and 94 deletions

View File

@@ -27,7 +27,7 @@ from .data import (is_sklearn_available,
                    glue_output_modes, glue_convert_examples_to_features,
                    glue_processors, glue_tasks_num_labels,
                    squad_convert_examples_to_features, SquadFeatures,
-                   SquadExample, read_squad_examples)
+                   SquadExample)
 if is_sklearn_available():
     from .data import glue_compute_metrics

View File

@@ -1,6 +1,6 @@
 from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
 from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
-from .processors import squad_convert_examples_to_features, SquadExample, read_squad_examples
+from .processors import squad_convert_examples_to_features, SquadExample
 from .metrics import is_sklearn_available
 if is_sklearn_available():

View File

@@ -1,4 +1,4 @@
 from .utils import InputExample, InputFeatures, DataProcessor
 from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
-from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, read_squad_examples
+from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample
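
Taken together, the three import hunks above remove read_squad_examples from the public import surface; example loading now goes through the processor classes introduced in the last file below. A usage sketch, assuming the top-level package is named transformers and a hypothetical local data directory:

    from transformers.data.processors.squad import SquadV2Processor

    processor = SquadV2Processor()
    examples = processor.get_train_examples("data/squad")  # reads train-v2.0.json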

View File

@@ -46,7 +46,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
     return cur_span_index == best_span_index
 
 def _new_check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     # if len(doc_spans) == 1:
@@ -92,7 +91,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     features = []
     new_features = []
     for (example_index, example) in enumerate(tqdm(examples)):
-        if is_training:
+        if is_training and not example.is_impossible:
             # Get start and end position
             answer_length = len(example.answer_text)
             start_position = example.start_position
@@ -105,6 +104,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                 continue
 
         tok_to_orig_index = []
         orig_to_tok_index = []
         all_doc_tokens = []
@@ -115,6 +115,18 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
 
+        if is_training and not example.is_impossible:
+            tok_start_position = orig_to_tok_index[example.start_position]
+            if example.end_position < len(example.doc_tokens) - 1:
+                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+            else:
+                tok_end_position = len(all_doc_tokens) - 1
+            (tok_start_position, tok_end_position) = _improve_answer_span(
+                all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
+            )
+
         spans = []
 
         truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
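
The added block re-derives the answer span in subtoken coordinates before _improve_answer_span refines it. A minimal standalone sketch of the word-to-subtoken bookkeeping it relies on, with a hypothetical toy_tokenize standing in for tokenizer.tokenize:

    def toy_tokenize(word):
        # Hypothetical stand-in for tokenizer.tokenize: split long words in two.
        return [word[:4], "##" + word[4:]] if len(word) > 4 else [word]

    doc_tokens = ["unanswerable", "questions", "exist"]
    tok_to_orig_index, orig_to_tok_index, all_doc_tokens = [], [], []
    for i, token in enumerate(doc_tokens):
        orig_to_tok_index.append(len(all_doc_tokens))  # first subtoken of word i
        for sub_token in toy_tokenize(token):
            tok_to_orig_index.append(i)                # subtoken -> owning word
            all_doc_tokens.append(sub_token)

    assert orig_to_tok_index == [0, 2, 4]
    assert all_doc_tokens[orig_to_tok_index[1]] == "ques"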
@@ -187,6 +199,34 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # Set the CLS index to '0'
             p_mask[cls_index] = 0
 
+            span_is_impossible = example.is_impossible
+            start_position = 0
+            end_position = 0
+            if is_training and not span_is_impossible:
+                # For training, if our document chunk does not contain an annotation
+                # we throw it out, since there is nothing to predict.
+                doc_start = span["start"]
+                doc_end = span["start"] + span["length"] - 1
+                out_of_span = False
+                if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
+                    out_of_span = True
+                if out_of_span:
+                    start_position = cls_index
+                    end_position = cls_index
+                    span_is_impossible = True
+                else:
+                    if sequence_a_is_doc:
+                        doc_offset = 0
+                    else:
+                        doc_offset = len(truncated_query) + sequence_added_tokens
+                    start_position = tok_start_position - doc_start + doc_offset
+                    end_position = tok_end_position - doc_start + doc_offset
+
             new_features.append(NewSquadFeatures(
                 span['input_ids'],
                 span['attention_mask'],
@@ -199,7 +239,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 paragraph_len=span['paragraph_len'],
                 token_is_max_context=span["token_is_max_context"],
                 tokens=span["tokens"],
-                token_to_orig_map=span["token_to_orig_map"]
+                token_to_orig_map=span["token_to_orig_map"],
+                start_position=start_position,
+                end_position=end_position
             ))
             unique_id += 1
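
Two things happen in the block added above: unanswerable examples inherit span_is_impossible from the example, and answerable ones whose gold span falls outside the current document chunk are relabeled to the CLS index. For the in-span case, the positions are re-based into the chunk's coordinate frame; a toy rendition of that arithmetic with made-up values:

    tok_start_position, tok_end_position = 40, 43  # gold span, doc-token space
    doc_start, doc_length = 20, 120                # chunk covers tokens 20..139
    doc_end = doc_start + doc_length - 1
    doc_offset = 12                                # truncated query + special tokens

    assert doc_start <= tok_start_position and tok_end_position <= doc_end
    start_position = tok_start_position - doc_start + doc_offset  # 32
    end_position = tok_end_position - doc_start + doc_offset      # 35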
@@ -207,86 +250,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     return new_features
 
-def read_squad_examples(input_file, is_training, version_2_with_negative):
-    """Read a SQuAD json file into a list of SquadExample."""
-    with open(input_file, "r", encoding='utf-8') as reader:
-        input_data = json.load(reader)["data"]
-
-    def is_whitespace(c):
-        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
-            return True
-        return False
-
-    examples = []
-    for entry in input_data:
-        for paragraph in entry["paragraphs"]:
-            paragraph_text = paragraph["context"]
-            doc_tokens = []
-            char_to_word_offset = []
-            prev_is_whitespace = True
-            for c in paragraph_text:
-                if is_whitespace(c):
-                    prev_is_whitespace = True
-                else:
-                    if prev_is_whitespace:
-                        doc_tokens.append(c)
-                    else:
-                        doc_tokens[-1] += c
-                    prev_is_whitespace = False
-                char_to_word_offset.append(len(doc_tokens) - 1)
-
-            for qa in paragraph["qas"]:
-                qas_id = qa["id"]
-                question_text = qa["question"]
-                start_position = None
-                end_position = None
-                orig_answer_text = None
-                is_impossible = False
-                if is_training:
-                    if version_2_with_negative:
-                        is_impossible = qa["is_impossible"]
-                    if (len(qa["answers"]) != 1) and (not is_impossible):
-                        raise ValueError(
-                            "For training, each question should have exactly 1 answer.")
-                    if not is_impossible:
-                        answer = qa["answers"][0]
-                        orig_answer_text = answer["text"]
-                        answer_offset = answer["answer_start"]
-                        answer_length = len(orig_answer_text)
-                        start_position = char_to_word_offset[answer_offset]
-                        end_position = char_to_word_offset[answer_offset + answer_length - 1]
-                        # Only add answers where the text can be exactly recovered from the
-                        # document. If this CAN'T happen it's likely due to weird Unicode
-                        # stuff so we will just skip the example.
-                        #
-                        # Note that this means for training mode, every example is NOT
-                        # guaranteed to be preserved.
-                        actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
-                        cleaned_answer_text = " ".join(
-                            whitespace_tokenize(orig_answer_text))
-                        if actual_text.find(cleaned_answer_text) == -1:
-                            logger.warning("Could not find answer: '%s' vs. '%s'",
-                                           actual_text, cleaned_answer_text)
-                            continue
-                    else:
-                        start_position = -1
-                        end_position = -1
-                        orig_answer_text = ""
-
-                example = SquadExample(
-                    qas_id=qas_id,
-                    question_text=question_text,
-                    doc_tokens=doc_tokens,
-                    orig_answer_text=orig_answer_text,
-                    start_position=start_position,
-                    end_position=end_position,
-                    is_impossible=is_impossible)
-                examples.append(example)
-    return examples
-
-class SquadV1Processor(DataProcessor):
+class SquadProcessor(DataProcessor):
     """Processor for the SQuAD data set."""
+    train_file = None
+    dev_file = None
 
     def get_example_from_tensor_dict(self, tensor_dict):
         """See base class."""
@@ -301,13 +268,19 @@ class SquadV1Processor(DataProcessor):
     def get_train_examples(self, data_dir, only_first=None):
         """See base class."""
-        with open(os.path.join(data_dir, "train-v1.1.json"), "r", encoding='utf-8') as reader:
+        if self.train_file is None:
+            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+
+        with open(os.path.join(data_dir, self.train_file), "r", encoding='utf-8') as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "train", only_first)
 
     def get_dev_examples(self, data_dir, only_first=None):
         """See base class."""
-        with open(os.path.join(data_dir, "dev-v1.1.json"), "r", encoding='utf-8') as reader:
+        if self.dev_file is None:
+            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+
+        with open(os.path.join(data_dir, self.dev_file), "r", encoding='utf-8') as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "dev", only_first)
@@ -329,7 +302,13 @@ class SquadV1Processor(DataProcessor):
                     question_text = qa["question"]
                     start_position_character = None
                     answer_text = None
-                    if is_training:
+
+                    if "is_impossible" in qa:
+                        is_impossible = qa["is_impossible"]
+                    else:
+                        is_impossible = False
+
+                    if not is_impossible and is_training:
                         if (len(qa["answers"]) != 1):
                             raise ValueError(
                                 "For training, each question should have exactly 1 answer.")
@@ -343,15 +322,25 @@ class SquadV1Processor(DataProcessor):
                         context_text=context_text,
                         answer_text=answer_text,
                         start_position_character=start_position_character,
-                        title=title
+                        title=title,
+                        is_impossible=is_impossible
                     )
 
                     examples.append(example)
                     if only_first is not None and len(examples) > only_first:
                         return examples
         return examples
 
+
+class SquadV1Processor(SquadProcessor):
+    train_file = "train-v1.1.json"
+    dev_file = "dev-v1.1.json"
+
+
+class SquadV2Processor(SquadProcessor):
+    train_file = "train-v2.0.json"
+    dev_file = "dev-v2.0.json"
+
 
 class NewSquadExample(object):
     """
@@ -364,13 +353,16 @@ class NewSquadExample(object):
                  context_text,
                  answer_text,
                  start_position_character,
-                 title):
+                 title,
+                 is_impossible=False):
         self.qas_id = qas_id
         self.question_text = question_text
         self.context_text = context_text
         self.answer_text = answer_text
         self.title = title
-        self.is_impossible = False
+        self.is_impossible = is_impossible
+        self.start_position, self.end_position = 0, 0
 
         doc_tokens = []
         char_to_word_offset = []
@@ -392,7 +384,7 @@ class NewSquadExample(object):
         self.char_to_word_offset = char_to_word_offset
 
         # Start and end positions only have a value during evaluation.
-        if start_position_character is not None:
+        if start_position_character is not None and not is_impossible:
             self.start_position = char_to_word_offset[start_position_character]
             self.end_position = char_to_word_offset[start_position_character + len(answer_text) - 1]
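
char_to_word_offset maps every character of the context to the whitespace-delimited word containing it, so the character-level answer_start from the JSON converts to word-level positions with two lookups. A small worked example (c.isspace() stands in for the stricter whitespace test used in the class):

    context = "Paris is nice"
    doc_tokens, char_to_word_offset = [], []
    prev_is_whitespace = True
    for c in context:
        if c.isspace():
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)    # start a new word
            else:
                doc_tokens[-1] += c     # extend the current word
            prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)

    answer_text, start_char = "is", 6
    start_position = char_to_word_offset[start_char]                       # word 1
    end_position = char_to_word_offset[start_char + len(answer_text) - 1]  # word 1
    assert doc_tokens[start_position:end_position + 1] == ["is"]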
@@ -415,7 +407,10 @@ class NewSquadFeatures(object):
                  paragraph_len,
                  token_is_max_context,
                  tokens,
-                 token_to_orig_map
+                 token_to_orig_map,
+                 start_position,
+                 end_position
                  ):
         self.input_ids = input_ids
         self.attention_mask = attention_mask
@@ -430,6 +425,9 @@ class NewSquadFeatures(object):
         self.tokens = tokens
         self.token_to_orig_map = token_to_orig_map
 
+        self.start_position = start_position
+        self.end_position = end_position
+
 
 class SquadExample(object):
     """
     A single training/test example for the Squad dataset.
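
With start_position and end_position stored per feature, label tensors for the span-prediction head can be built straight from the feature list. A hedged sketch of that downstream step (the tensorization itself is an assumption, not part of this commit; features is the list returned by squad_convert_examples_to_features):

    import torch

    # Sketch: per-feature labels -> training targets for the QA head.
    all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)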