New -> normal

Lysandre 2019-11-28 17:43:47 -05:00
parent 0b84b9fd8a
commit 1e9ac5a7cf


@@ -217,7 +217,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 end_position = tok_end_position - doc_start + doc_offset
-            features.append(NewSquadFeatures(
+            features.append(SquadFeatures(
                 span['input_ids'],
                 span['attention_mask'],
                 span['token_type_ids'],
@@ -246,7 +246,7 @@ class SquadProcessor(DataProcessor):
     dev_file = None
     def get_example_from_tensor_dict(self, tensor_dict):
-        return NewSquadExample(
+        return SquadExample(
             tensor_dict['id'].numpy().decode("utf-8"),
             tensor_dict['question'].numpy().decode('utf-8'),
             tensor_dict['context'].numpy().decode('utf-8'),
@@ -314,7 +314,7 @@ class SquadProcessor(DataProcessor):
                 answer_text = answer['text']
                 start_position_character = answer['answer_start']
-                example = NewSquadExample(
+                example = SquadExample(
                     qas_id=qas_id,
                     question_text=question_text,
                     context_text=context_text,
@@ -340,7 +340,7 @@ class SquadV2Processor(SquadProcessor):
     dev_file = "dev-v2.0.json"
-class NewSquadExample(object):
+class SquadExample(object):
     """
     A single training/test example for the Squad dataset, as loaded from disk.
     """
@@ -387,7 +387,7 @@ class NewSquadExample(object):
             self.end_position = char_to_word_offset[start_position_character + len(answer_text) - 1]
-class NewSquadFeatures(object):
+class SquadFeatures(object):
     """
     Single squad example features to be fed to a model.
     Those features are model-specific.
@@ -425,99 +425,3 @@ class NewSquadFeatures(object):
         self.start_position = start_position
         self.end_position = end_position
-class SquadExample(object):
-    """
-    A single training/test example for the Squad dataset.
-    For examples without an answer, the start and end position are -1.
-    """
-    def __init__(self,
-                 qas_id,
-                 question_text,
-                 doc_tokens,
-                 orig_answer_text=None,
-                 start_position=None,
-                 end_position=None,
-                 is_impossible=None):
-        self.qas_id = qas_id
-        self.question_text = question_text
-        self.doc_tokens = doc_tokens
-        self.orig_answer_text = orig_answer_text
-        self.start_position = start_position
-        self.end_position = end_position
-        self.is_impossible = is_impossible
-    def __str__(self):
-        return self.__repr__()
-    def __repr__(self):
-        s = ""
-        s += "qas_id: %s" % (self.qas_id)
-        s += ", question_text: %s" % (
-            self.question_text)
-        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
-        if self.start_position:
-            s += ", start_position: %d" % (self.start_position)
-        if self.end_position:
-            s += ", end_position: %d" % (self.end_position)
-        if self.is_impossible:
-            s += ", is_impossible: %r" % (self.is_impossible)
-        return s
-class SquadFeatures(object):
-    """A single set of features of data."""
-    def __init__(self,
-                 unique_id,
-                 example_index,
-                 doc_span_index,
-                 tokens,
-                 token_to_orig_map,
-                 token_is_max_context,
-                 input_ids,
-                 input_mask,
-                 segment_ids,
-                 cls_index,
-                 p_mask,
-                 paragraph_len,
-                 start_position=None,
-                 end_position=None,
-                 is_impossible=None):
-        self.unique_id = unique_id
-        self.example_index = example_index
-        self.doc_span_index = doc_span_index
-        self.tokens = tokens
-        self.token_to_orig_map = token_to_orig_map
-        self.token_is_max_context = token_is_max_context
-        self.input_ids = input_ids
-        self.input_mask = input_mask
-        self.segment_ids = segment_ids
-        self.cls_index = cls_index
-        self.p_mask = p_mask
-        self.paragraph_len = paragraph_len
-        self.start_position = start_position
-        self.end_position = end_position
-        self.is_impossible = is_impossible
-    def __eq__(self, other):
-        print(self.example_index == other.example_index)
-        print(self.input_ids == other.input_ids)
-        print(self.input_mask == other.attention_mask)
-        print(self.p_mask == other.p_mask)
-        print(self.paragraph_len == other.paragraph_len)
-        print(self.segment_ids == other.token_type_ids)
-        print(self.token_is_max_context == other.token_is_max_context)
-        print(self.token_to_orig_map == other.token_to_orig_map)
-        print(self.tokens == other.tokens)
-        return self.example_index == other.example_index and \
-            self.input_ids == other.input_ids and \
-            self.input_mask == other.attention_mask and \
-            self.p_mask == other.p_mask and \
-            self.paragraph_len == other.paragraph_len and \
-            self.segment_ids == other.token_type_ids and \
-            self.token_is_max_context == other.token_is_max_context and \
-            self.token_to_orig_map == other.token_to_orig_map and \
-            self.tokens == other.tokens
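
After this commit, SquadExample and SquadFeatures are the only versions of these classes left, and they carry tokenizer-style field names: what the deleted class stored as input_mask and segment_ids now arrives as span['attention_mask'] and span['token_type_ids']. Below is a minimal usage sketch built from the keyword arguments visible in this diff (qas_id, question_text, context_text); the remaining arguments and the import path follow the released transformers API and are assumptions rather than something shown in this commit, and the values are illustrative placeholders.

# Hypothetical usage sketch, not part of the commit.
from transformers.data.processors.squad import SquadExample

example = SquadExample(
    qas_id="5733be284776f41900661182",         # placeholder SQuAD question id
    question_text="To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?",
    context_text="Architecturally, the school has a Catholic character. ...",
    answer_text="Saint Bernadette Soubirous",  # assumed argument, not shown in this diff
    start_position_character=515,              # assumed argument: character offset of the answer
    title="University_of_Notre_Dame",          # assumed argument, not shown in this diff
)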