mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
SQuAD v2 BERT + XLNet
This commit is contained in:
  parent e0e55bc550
  commit 0669c1fcd1
@@ -27,7 +27,7 @@ from .data import (is_sklearn_available,
                    glue_output_modes, glue_convert_examples_to_features,
                    glue_processors, glue_tasks_num_labels,
                    squad_convert_examples_to_features, SquadFeatures,
-                   SquadExample, read_squad_examples)
+                   SquadExample)
 
 if is_sklearn_available():
     from .data import glue_compute_metrics
@@ -1,6 +1,6 @@
 from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
 from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
-from .processors import squad_convert_examples_to_features, SquadExample, read_squad_examples
+from .processors import squad_convert_examples_to_features, SquadExample
 
 from .metrics import is_sklearn_available
 if is_sklearn_available():
@@ -1,4 +1,4 @@
 from .utils import InputExample, InputFeatures, DataProcessor
 from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
-from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, read_squad_examples
+from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample
 
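The three hunks above remove read_squad_examples from the public import surface; example creation moves onto the processor classes introduced further down. A minimal sketch of the import side after this change, assuming the top-level package layout implied by the relative imports in these hunks (the API may have changed in later releases):

    # Import surface after this commit (sketch; names taken from the hunks above).
    # read_squad_examples is gone; SquadExample and the converter remain public.
    from transformers.data import squad_convert_examples_to_features, SquadExample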
@@ -46,7 +46,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
 
     return cur_span_index == best_span_index
 
-
 def _new_check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     # if len(doc_spans) == 1:
@@ -92,7 +91,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     features = []
     new_features = []
     for (example_index, example) in enumerate(tqdm(examples)):
-        if is_training:
+        if is_training and not example.is_impossible:
             # Get start and end position
             answer_length = len(example.answer_text)
             start_position = example.start_position
@@ -105,6 +104,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
             continue
 
+
         tok_to_orig_index = []
         orig_to_tok_index = []
         all_doc_tokens = []
@@ -115,6 +115,18 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)
 
+        if is_training and not example.is_impossible:
+            tok_start_position = orig_to_tok_index[example.start_position]
+            if example.end_position < len(example.doc_tokens) - 1:
+                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+            else:
+                tok_end_position = len(all_doc_tokens) - 1
+
+            (tok_start_position, tok_end_position) = _improve_answer_span(
+                all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
+            )
+
         spans = []
 
         truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
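The hunk above remaps the answer's word-level start/end positions to subtoken positions before document spans are built: each whitespace word may split into several subtokens, so word indices must be translated through orig_to_tok_index. A toy, self-contained illustration of that mapping with a made-up tokenizer (nothing here is from the diff):

    # Pretend tokenizer: "wonderful" splits into two subtokens.
    doc_tokens = ["What", "a", "wonderful", "day"]
    sub = {"What": ["what"], "a": ["a"], "wonderful": ["wonder", "##ful"], "day": ["day"]}
    orig_to_tok_index, all_doc_tokens = [], []
    for word in doc_tokens:
        orig_to_tok_index.append(len(all_doc_tokens))  # first subtoken of this word
        all_doc_tokens.extend(sub[word])
    # word 2 ("wonderful") starts at subtoken 2; word 3 ("day") at subtoken 4
    assert orig_to_tok_index == [0, 1, 2, 4]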
@@ -187,6 +199,34 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # Set the CLS index to '0'
             p_mask[cls_index] = 0
 
+            span_is_impossible = example.is_impossible
+            start_position = 0
+            end_position = 0
+            if is_training and not span_is_impossible:
+                # For training, if our document chunk does not contain an annotation
+                # we throw it out, since there is nothing to predict.
+                doc_start = span["start"]
+                doc_end = span["start"] + span["length"] - 1
+                out_of_span = False
+
+                if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
+                    out_of_span = True
+
+                if out_of_span:
+                    start_position = cls_index
+                    end_position = cls_index
+                    span_is_impossible = True
+                else:
+                    if sequence_a_is_doc:
+                        doc_offset = 0
+                    else:
+                        doc_offset = len(truncated_query) + sequence_added_tokens
+
+                    start_position = tok_start_position - doc_start + doc_offset
+                    end_position = tok_end_position - doc_start + doc_offset
+
             new_features.append(NewSquadFeatures(
                 span['input_ids'],
                 span['attention_mask'],
@@ -199,7 +239,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 paragraph_len=span['paragraph_len'],
                 token_is_max_context=span["token_is_max_context"],
                 tokens=span["tokens"],
-                token_to_orig_map=span["token_to_orig_map"]
+                token_to_orig_map=span["token_to_orig_map"],
+
+                start_position=start_position,
+                end_position=end_position
             ))
 
             unique_id += 1
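The new labeling logic above clamps answers that fall outside the current document span to the CLS index (marking the span impossible), and otherwise shifts the subtoken positions by the document's offset inside the packed query+document sequence. Toy numbers for that arithmetic (all values invented, pure Python):

    tok_start_position, tok_end_position = 30, 33   # answer span, subtoken indices
    doc_start = 20                                  # where this doc span begins
    len_truncated_query, sequence_added_tokens = 12, 2
    doc_offset = len_truncated_query + sequence_added_tokens  # query precedes doc
    start_position = tok_start_position - doc_start + doc_offset
    end_position = tok_end_position - doc_start + doc_offset
    assert (start_position, end_position) == (24, 27)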
@@ -207,86 +250,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     return new_features
 
 
-def read_squad_examples(input_file, is_training, version_2_with_negative):
-    """Read a SQuAD json file into a list of SquadExample."""
-    with open(input_file, "r", encoding='utf-8') as reader:
-        input_data = json.load(reader)["data"]
-
-    def is_whitespace(c):
-        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
-            return True
-        return False
-
-    examples = []
-    for entry in input_data:
-        for paragraph in entry["paragraphs"]:
-            paragraph_text = paragraph["context"]
-            doc_tokens = []
-            char_to_word_offset = []
-            prev_is_whitespace = True
-            for c in paragraph_text:
-                if is_whitespace(c):
-                    prev_is_whitespace = True
-                else:
-                    if prev_is_whitespace:
-                        doc_tokens.append(c)
-                    else:
-                        doc_tokens[-1] += c
-                    prev_is_whitespace = False
-                char_to_word_offset.append(len(doc_tokens) - 1)
-
-            for qa in paragraph["qas"]:
-                qas_id = qa["id"]
-                question_text = qa["question"]
-                start_position = None
-                end_position = None
-                orig_answer_text = None
-                is_impossible = False
-                if is_training:
-                    if version_2_with_negative:
-                        is_impossible = qa["is_impossible"]
-                    if (len(qa["answers"]) != 1) and (not is_impossible):
-                        raise ValueError(
-                            "For training, each question should have exactly 1 answer.")
-                    if not is_impossible:
-                        answer = qa["answers"][0]
-                        orig_answer_text = answer["text"]
-                        answer_offset = answer["answer_start"]
-                        answer_length = len(orig_answer_text)
-                        start_position = char_to_word_offset[answer_offset]
-                        end_position = char_to_word_offset[answer_offset + answer_length - 1]
-                        # Only add answers where the text can be exactly recovered from the
-                        # document. If this CAN'T happen it's likely due to weird Unicode
-                        # stuff so we will just skip the example.
-                        #
-                        # Note that this means for training mode, every example is NOT
-                        # guaranteed to be preserved.
-                        actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
-                        cleaned_answer_text = " ".join(
-                            whitespace_tokenize(orig_answer_text))
-                        if actual_text.find(cleaned_answer_text) == -1:
-                            logger.warning("Could not find answer: '%s' vs. '%s'",
-                                           actual_text, cleaned_answer_text)
-                            continue
-                    else:
-                        start_position = -1
-                        end_position = -1
-                        orig_answer_text = ""
-
-                example = SquadExample(
-                    qas_id=qas_id,
-                    question_text=question_text,
-                    doc_tokens=doc_tokens,
-                    orig_answer_text=orig_answer_text,
-                    start_position=start_position,
-                    end_position=end_position,
-                    is_impossible=is_impossible)
-                examples.append(example)
-    return examples
-
-
-class SquadV1Processor(DataProcessor):
+class SquadProcessor(DataProcessor):
     """Processor for the SQuAD data set."""
+    train_file = None
+    dev_file = None
 
     def get_example_from_tensor_dict(self, tensor_dict):
         """See base class."""
@@ -301,13 +268,19 @@ class SquadV1Processor(DataProcessor):
 
     def get_train_examples(self, data_dir, only_first=None):
         """See base class."""
-        with open(os.path.join(data_dir, "train-v1.1.json"), "r", encoding='utf-8') as reader:
+        if self.train_file is None:
+            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+
+        with open(os.path.join(data_dir, self.train_file), "r", encoding='utf-8') as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "train", only_first)
 
     def get_dev_examples(self, data_dir, only_first=None):
         """See base class."""
-        with open(os.path.join(data_dir, "dev-v1.1.json"), "r", encoding='utf-8') as reader:
+        if self.dev_file is None:
+            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+
+        with open(os.path.join(data_dir, self.dev_file), "r", encoding='utf-8') as reader:
            input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "dev", only_first)
 
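The base class now fails fast unless a concrete subclass fills in train_file/dev_file, turning the old hard-coded filenames into a template-method pattern. A stripped-down, self-contained sketch of that pattern with toy class names (not the real API):

    class BaseProcessor:
        train_file = None  # subclasses must override

        def get_train_examples(self):
            if self.train_file is None:
                raise ValueError("instantiate a concrete subclass instead")
            return f"would read {self.train_file}"

    class V2Processor(BaseProcessor):
        train_file = "train-v2.0.json"

    print(V2Processor().get_train_examples())  # would read train-v2.0.json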
@@ -329,7 +302,13 @@ class SquadV1Processor(DataProcessor):
                 question_text = qa["question"]
                 start_position_character = None
                 answer_text = None
-                if is_training:
+
+                if "is_impossible" in qa:
+                    is_impossible = qa["is_impossible"]
+                else:
+                    is_impossible = False
+
+                if not is_impossible and is_training:
                     if (len(qa["answers"]) != 1):
                         raise ValueError(
                             "For training, each question should have exactly 1 answer.")
@@ -343,15 +322,25 @@ class SquadV1Processor(DataProcessor):
                     context_text=context_text,
                     answer_text=answer_text,
                     start_position_character=start_position_character,
-                    title=title
+                    title=title,
+                    is_impossible=is_impossible
                 )
 
                 examples.append(example)
 
                 if only_first is not None and len(examples) > only_first:
                     return examples
         return examples
 
 
+class SquadV1Processor(SquadProcessor):
+    train_file = "train-v1.1.json"
+    dev_file = "dev-v1.1.json"
+
+
+class SquadV2Processor(SquadProcessor):
+    train_file = "train-v2.0.json"
+    dev_file = "dev-v2.0.json"
+
 
 class NewSquadExample(object):
     """
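With the two subclasses in place, loading SQuAD v2 examples might look like the following. This is a hedged sketch: it assumes a local data/squad directory containing the official JSON files and the module path implied by the import hunks above, and the processor API may have changed in later releases:

    from transformers.data.processors.squad import SquadV2Processor

    processor = SquadV2Processor()
    # only_first caps how many examples are returned, per the diff above
    examples = processor.get_train_examples("data/squad", only_first=100)
    print(len(examples), examples[0].question_text)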
@@ -364,13 +353,16 @@ class NewSquadExample(object):
                  context_text,
                  answer_text,
                  start_position_character,
-                 title):
+                 title,
+                 is_impossible=False):
         self.qas_id = qas_id
         self.question_text = question_text
         self.context_text = context_text
         self.answer_text = answer_text
         self.title = title
-        self.is_impossible = False
+        self.is_impossible = is_impossible
+
         self.start_position, self.end_position = 0, 0
 
         doc_tokens = []
         char_to_word_offset = []
@@ -392,7 +384,7 @@ class NewSquadExample(object):
         self.char_to_word_offset = char_to_word_offset
 
         # Start end end positions only has a value during evaluation.
-        if start_position_character is not None:
+        if start_position_character is not None and not is_impossible:
             self.start_position = char_to_word_offset[start_position_character]
             self.end_position = char_to_word_offset[start_position_character + len(answer_text) - 1]
 
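The guard above skips the position lookup for impossible questions, which have no answer to index into char_to_word_offset. For reference, a toy standalone version of how that offset table is built (same whitespace-splitting logic as the constructor above, nothing added beyond it):

    # char_to_word_offset maps each character position to its word index.
    context = "New York City"
    doc_tokens, char_to_word_offset, prev_is_whitespace = [], [], True
    for c in context:
        if c == " ":
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)      # start a new word
            else:
                doc_tokens[-1] += c       # extend the current word
            prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)
    assert doc_tokens == ["New", "York", "City"]
    assert char_to_word_offset[4] == 1    # the "Y" of "York" is in word 1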
@@ -415,7 +407,10 @@ class NewSquadFeatures(object):
                  paragraph_len,
                  token_is_max_context,
                  tokens,
-                 token_to_orig_map
+                 token_to_orig_map,
+
+                 start_position,
+                 end_position
                  ):
         self.input_ids = input_ids
         self.attention_mask = attention_mask
@@ -430,6 +425,9 @@ class NewSquadFeatures(object):
         self.tokens = tokens
         self.token_to_orig_map = token_to_orig_map
 
+        self.start_position = start_position
+        self.end_position = end_position
+
 class SquadExample(object):
     """
     A single training/test example for the Squad dataset.
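start_position and end_position now travel with each feature so a training loop can read label indices next to the inputs. A hedged sketch of the consuming side, assuming PyTorch and a stand-in feature class (neither is part of this diff):

    import torch

    class Feat:
        """Stand-in for NewSquadFeatures, keeping only the fields used here."""
        def __init__(self, input_ids, start_position, end_position):
            self.input_ids = input_ids
            self.start_position = start_position
            self.end_position = end_position

    features = [Feat([101, 2054, 102], 1, 1), Feat([101, 2154, 102], 0, 0)]
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
    # e.g. torch.utils.data.TensorDataset(all_input_ids, all_start_positions, all_end_positions)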